diff --git a/code/xtuner/.DS_Store b/code/xtuner/.DS_Store
deleted file mode 100644
index fa557d1ea4c50262c092c5b0b0430e1724c60d5f..0000000000000000000000000000000000000000
Binary files a/code/xtuner/.DS_Store and /dev/null differ
diff --git a/code/xtuner/__init__.py b/code/xtuner/__init__.py
deleted file mode 100644
index cb1d94302bdd08088746432918edccd3a306d874..0000000000000000000000000000000000000000
--- a/code/xtuner/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-
-from mmengine.utils import digit_version
-
-from .entry_point import cli
-from .version import __version__, version_info
-
-HF_CEPH_HUB = os.getenv('HF_CEPH_HUB', '')
-HF_USE_CEPH = os.getenv('HF_USE_CEPH', 0) or HF_CEPH_HUB != ''
-DS_CEPH_DIR = os.getenv('DS_CEPH_DIR', None)
-if HF_USE_CEPH:
- from .utils.fileio import (patch_hf_auto_from_pretrained,
- patch_hf_save_pretrained)
- patch_hf_auto_from_pretrained(HF_CEPH_HUB)
- patch_hf_save_pretrained()
-
-if DS_CEPH_DIR:
- from .utils.fileio import patch_deepspeed_engine
- patch_deepspeed_engine()
-
-__all__ = [
- '__version__', 'version_info', 'digit_version', 'cli', 'HF_USE_CEPH',
- 'DS_CEPH_DIR'
-]
diff --git a/code/xtuner/__pycache__/__init__.cpython-311.pyc b/code/xtuner/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index a1818d9244a0607d1e7bd710c505ffc0d3bde159..0000000000000000000000000000000000000000
Binary files a/code/xtuner/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/__pycache__/entry_point.cpython-311.pyc b/code/xtuner/__pycache__/entry_point.cpython-311.pyc
deleted file mode 100644
index 7820db4ec629c3cbbe30565701b055c53bc54c5a..0000000000000000000000000000000000000000
Binary files a/code/xtuner/__pycache__/entry_point.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/__pycache__/registry.cpython-311.pyc b/code/xtuner/__pycache__/registry.cpython-311.pyc
deleted file mode 100644
index d0436cbe9c2ca2a124bb5a9e045cf8a4e235a1fe..0000000000000000000000000000000000000000
Binary files a/code/xtuner/__pycache__/registry.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/__pycache__/version.cpython-311.pyc b/code/xtuner/__pycache__/version.cpython-311.pyc
deleted file mode 100644
index d560259923346cdb8dd20bd3b11b639c0cab7988..0000000000000000000000000000000000000000
Binary files a/code/xtuner/__pycache__/version.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/_lite/.DS_Store b/code/xtuner/_lite/.DS_Store
deleted file mode 100644
index 629b6e6c7ef52dfe0777c50aa484a9f0de9910ab..0000000000000000000000000000000000000000
Binary files a/code/xtuner/_lite/.DS_Store and /dev/null differ
diff --git a/code/xtuner/_lite/__init__.py b/code/xtuner/_lite/__init__.py
deleted file mode 100644
index 513b0c574a5133fcbadd52765229be9648095810..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/__init__.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import subprocess
-import sys
-
-from loguru import logger
-
-from .device import get_device, get_torch_device_module
-
-_LOGGER = None
-
-
-def log_format(debug=False):
- formatter = "[XTuner][{time:YYYY-MM-DD HH:mm:ss}][{level}]"
-
- if debug:
- formatter += "[{name}:"
- formatter += "{function}:"
- formatter += "{line}]"
-
- formatter += " {message}"
- return formatter
-
-
-def get_logger(level="INFO"):
- global _LOGGER
- if _LOGGER is None:
- # Remove the original logger in Python to prevent duplicate printing.
- logger.remove()
- logger.add(sys.stderr, level=level, format=log_format(debug=level == "DEBUG"))
- _LOGGER = logger
- return _LOGGER
-
-
-def get_repo_git_info(repo_path):
- original_directory = os.getcwd()
- os.chdir(repo_path)
-
- try:
- branch = (
- subprocess.check_output(
- ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.STDOUT
- )
- .strip()
- .decode("utf-8")
- )
-
- commit_id = (
- subprocess.check_output(
- ["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT
- )
- .strip()
- .decode("utf-8")
- )
-
- remote_url = (
- subprocess.check_output(
- ["git", "remote", "get-url", "origin"], stderr=subprocess.STDOUT
- )
- .strip()
- .decode("utf-8")
- )
-
- return branch, commit_id, remote_url
- except subprocess.CalledProcessError:
- return None, None, None
- finally:
- os.chdir(original_directory)
-
-
-__all__ = [
- "AutoConfig",
- "AutoModelForCausalLM",
- "AutoTokenizer",
- "get_device",
- "get_torch_device_module",
-]
diff --git a/code/xtuner/_lite/accelerate/__init__.py b/code/xtuner/_lite/accelerate/__init__.py
deleted file mode 100644
index 0ea01cf5b0305d51d6390cd349e59ddf2e3b8ac3..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/accelerate/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .lora import LORA_TARGET_MAP
-from .packed import pack_sequence, unpack_sequence
-from .utils import (
- liger_kernel_is_available,
- lmdeploy_is_available,
- mlu_is_available,
- npu_is_available,
- profile_time_and_memory,
- varlen_attn_is_available,
-)
-
-__all__ = [
- "LORA_TARGET_MAP",
- "pack_sequence",
- "packed_sequence",
- "unpack_sequence",
- "liger_kernel_is_available",
- "varlen_attn_is_available",
- "lmdeploy_is_available",
- "npu_is_available",
- "mlu_is_available",
- "profile_time_and_memory",
-]
diff --git a/code/xtuner/_lite/accelerate/lora.py b/code/xtuner/_lite/accelerate/lora.py
deleted file mode 100644
index c473c2c235d921392c7ba49fe1239ea24db42668..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/accelerate/lora.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-LORA_TARGET_MAP = {
- "InternLM2ForCausalLM": ["wqkv", "wo", "w1", "w2", "w3"],
- "CLIPVisionModel": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
-}
diff --git a/code/xtuner/_lite/accelerate/ops/__init__.py b/code/xtuner/_lite/accelerate/ops/__init__.py
deleted file mode 100644
index b992157a1eda43fa0e805a187b4032cc9154998c..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/accelerate/ops/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .moe_permute import GROUPED_GEMM_INSTALLED, permute_func, unpermute_func
-
-__all__ = ["GROUPED_GEMM_INSTALLED", "permute_func", "unpermute_func"]
diff --git a/code/xtuner/_lite/accelerate/ops/moe_permute.py b/code/xtuner/_lite/accelerate/ops/moe_permute.py
deleted file mode 100644
index b6d4248c20393bfd3ee506dede718d0feb4c48a6..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/accelerate/ops/moe_permute.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-"""Modified from
-https://github.com/fanshiqing/grouped_gemm/blob/v1.1.4/grouped_gemm/ops.py
-Support torch compile."""
-from typing import Optional, Tuple
-
-import torch
-from torch import Tensor
-
-GROUPED_GEMM_INSTALLED = False
-
-try:
- from grouped_gemm import backend
-
- GROUPED_GEMM_INSTALLED = True
-except ImportError:
- # install grouped gemm https://github.com/fanshiqing/grouped_gemm/tree/v1.1.4?tab=readme-ov-file#pip-install
- grouped_gmm = None
-
-
-@torch.library.custom_op("moe::permute", mutates_args=())
-def permute(input_act: Tensor, indices: Tensor, num_topK: int) -> Tuple[Tensor, Tensor]:
- input_max_expanded_token_num = input_act.size(0) * num_topK
- workspace_fw = []
- permuted_act, row_id_map, _ = backend.permute(
- input_act, indices, 0, workspace_fw, input_max_expanded_token_num
- )
- return permuted_act, row_id_map
-
-
-@permute.register_fake
-def permute_fake(
- input_act: Tensor,
- indices: Tensor,
- num_topK: int,
-):
- permuted_act = input_act.new_empty(
- (input_act.shape[0] * num_topK, *input_act.shape[1:])
- )
- row_id_map = indices.new_empty((indices.numel(),))
- return permuted_act, row_id_map
-
-
-@torch.library.custom_op("moe::unpermute", mutates_args=())
-def unpermute(
- input: Tensor, row_id_map: Tensor, prob: Tensor, max_tokens: int, num_topK: int
-) -> Tensor:
- if not input.is_contiguous():
- input = input.contiguous()
- return backend.unpermute(input, row_id_map, prob, max_tokens, num_topK)
-
-
-@unpermute.register_fake
-def unpermute_fake(
- input: Tensor, row_id_map: Tensor, prob: Tensor, max_tokens: int, num_topK: int
-) -> Tensor:
- return input.new_empty((input.shape[0] // num_topK, *input.shape[1:]))
-
-
-@torch.library.custom_op("moe::unpermute_bwd", mutates_args=())
-def unpermute_bwd(
- input_bwd: Tensor,
- input_fwd: Tensor,
- row_id_map: Tensor,
- prob: Optional[Tensor],
-) -> Tuple[Tensor, Tensor]:
- if not input_bwd.is_contiguous():
- input_bwd = input_bwd.contiguous()
- topk = input_fwd.shape[0] // input_bwd.shape[0]
- if prob is None:
- prob = torch.ones(
- [input_bwd.size(0), topk], dtype=torch.float32, device=input_bwd.device
- )
- return backend.unpermute_bwd(input_bwd, input_fwd, row_id_map, prob)
-
-
-@unpermute_bwd.register_fake
-def unpermute_bwd_fake(
- input_bwd: Tensor,
- input_fwd: Tensor,
- row_id_map: Tensor,
- prob: Optional[Tensor],
-) -> Tuple[Tensor, Tensor]:
- act_grad = torch.empty_like(input_fwd)
- topk = input_fwd.shape[0] // input_bwd.shape[0]
- prob_grad = torch.empty(
- (input_bwd.size(0), topk), dtype=torch.float32, device=input_bwd.device
- )
- return act_grad, prob_grad
-
-
-if torch.__version__ >= "2.4.0":
- _wrapped_permute = torch.ops.moe.permute
- _wrapped_unpermute = torch.ops.moe.unpermute
- _wrapped_unpermute_bwd = torch.ops.moe.unpermute_bwd
-else:
- _wrapped_permute = permute
- _wrapped_unpermute = unpermute
- _wrapped_unpermute_bwd = unpermute_bwd
-
-
-class PermuteMoE_topK(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- input_act: Tensor,
- indices: Tensor,
- ):
- if not input_act.numel():
- return input_act, None
-
- if indices.dim() == 1:
- indices = indices.view(-1, 1)
- if not input_act.is_contiguous():
- input_act = input_act.contiguous()
- if not indices.is_contiguous():
- indices = indices.contiguous()
-
- num_topK = indices.size(1)
-
- permuted_act, row_id_map = _wrapped_permute(
- input_act,
- indices,
- num_topK,
- )
-
- ctx.row_id_map = row_id_map
- ctx.num_tokens = indices.size(0)
- ctx.num_topK = num_topK
- return permuted_act, row_id_map
-
- @staticmethod
- def backward(ctx, permuted_act_grad, *args):
- if not permuted_act_grad.numel():
- return permuted_act_grad, None
-
- permuted_act_grad = permuted_act_grad.contiguous()
-
- row_id_map = ctx.row_id_map
- num_tokens = ctx.num_tokens
- num_topK = ctx.num_topK
-
- unpermuted_act_grad = _wrapped_unpermute(
- permuted_act_grad, row_id_map, torch.tensor([]), num_tokens, num_topK
- )
- return unpermuted_act_grad, None
-
-
-class UnpermuteMoE_topK(torch.autograd.Function):
- @staticmethod
- def forward(ctx, input_act: Tensor, row_id_map: Tensor, probs: Tensor = None):
- if not input_act.numel():
- ctx.probs = probs
- return input_act
-
- if not input_act.is_contiguous():
- input_act = input_act.contiguous()
- if not row_id_map.is_contiguous():
- row_id_map = row_id_map.contiguous()
- if probs is not None and not probs.is_contiguous():
- probs = probs.contiguous()
-
- num_tokens = probs.size(0) if probs is not None else input_act.size(0)
- num_topK = probs.size(1) if probs is not None else 1
-
- unpermuted_output = _wrapped_unpermute(
- input_act,
- row_id_map,
- probs if probs is not None else torch.tensor([]),
- num_tokens,
- num_topK,
- )
-
- ctx.save_for_backward(input_act, row_id_map, probs)
- return unpermuted_output
-
- @staticmethod
- def backward(ctx, unpermuted_act_grad):
- if not unpermuted_act_grad.numel():
- return unpermuted_act_grad, None, ctx.probs
-
- input_act, row_id_map, probs = ctx.saved_tensors
-
- act_grad = None
- if ctx.needs_input_grad[0]:
- act_grad, prob_grad = _wrapped_unpermute_bwd(
- unpermuted_act_grad, input_act, row_id_map, probs
- )
-
- if not ctx.needs_input_grad[2]:
- prob_grad = None
- return act_grad, None, prob_grad
-
-
-def permute_func(input_act, indices):
- return PermuteMoE_topK.apply(input_act, indices)
-
-
-def unpermute_func(input_act, row_id_map, probs=None):
- return UnpermuteMoE_topK.apply(input_act, row_id_map, probs)
diff --git a/code/xtuner/_lite/accelerate/packed.py b/code/xtuner/_lite/accelerate/packed.py
deleted file mode 100644
index 3351eed9a8f4df2c4cedc7a5a5528894347eabfe..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/accelerate/packed.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Union
-
-import torch
-
-
-def unpack_sequence(packed: torch.Tensor, num_tokens: Union[torch.Tensor, List], dim=1):
- if isinstance(num_tokens, torch.Tensor):
- num_tokens = num_tokens.tolist()
- sequences = torch.split(packed, num_tokens, dim=dim)
- return sequences
-
-
-def pack_sequence(sequences, dim=1):
- num_tokens = torch.IntTensor([seq.size(dim) for seq in sequences])
- packed = torch.cat(sequences, dim=dim)
- return packed, num_tokens.to(packed.device)
-
-
-def packed_cumulative_length(num_tokens: torch.Tensor):
- device = num_tokens.device
- _zero_pad = torch.zeros(1, device=device)
- _pad_length = torch.cat([_zero_pad, num_tokens]).int()
- return torch.cumsum(_pad_length, 0).int()
diff --git a/code/xtuner/_lite/accelerate/utils.py b/code/xtuner/_lite/accelerate/utils.py
deleted file mode 100644
index 8caa248abcc61a79616d2b5a519b89a0684a1cce..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/accelerate/utils.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import time
-from contextlib import contextmanager
-
-from transformers.utils.import_utils import is_flash_attn_2_available
-
-from xtuner._lite import get_device, get_logger, get_torch_device_module
-
-logger = get_logger()
-
-
-def npu_is_available():
- return get_device() == "npu"
-
-
-def mlu_is_available():
- return get_device() == "mlu"
-
-
-def varlen_attn_is_available():
- return is_flash_attn_2_available() or npu_is_available()
-
-
-def lmdeploy_is_available():
- available = False
- try:
- import lmdeploy # noqa: F401
-
- available = True
- except ImportError:
- available = False
-
- return available
-
-
-def liger_kernel_is_available():
- available = False
- try:
- import liger_kernel # noqa: F401
-
- available = True
- except ImportError:
- available = False
-
- return available
-
-
-@contextmanager
-def profile_time_and_memory(desc):
- torch_device = get_torch_device_module()
- start_t = time.time()
- torch_device.reset_peak_memory_stats()
-
- yield
-
- max_memory = torch_device.max_memory_allocated()
- cost_time = time.time() - start_t
-
- logger.success(
- f"{desc} Elapsed time {cost_time:.2f} seconds, "
- f"peak gpu memory {max_memory/1024**3:.1f}G"
- )
diff --git a/code/xtuner/_lite/algorithms/.DS_Store b/code/xtuner/_lite/algorithms/.DS_Store
deleted file mode 100644
index ad4a3ba4b58b587e784ec2169b2873fc8274f12c..0000000000000000000000000000000000000000
Binary files a/code/xtuner/_lite/algorithms/.DS_Store and /dev/null differ
diff --git a/code/xtuner/_lite/algorithms/__init__.py b/code/xtuner/_lite/algorithms/__init__.py
deleted file mode 100644
index ef101fec61e72abc0eb90266d453b5b22331378d..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/code/xtuner/_lite/algorithms/ppo/__init__.py b/code/xtuner/_lite/algorithms/ppo/__init__.py
deleted file mode 100644
index e091226a8ea1086da9e99c5688e031f5e6e4d327..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/ppo/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .dataset import (
- InferDataset,
- PPOTokenizeFunction,
- RewardBuffer,
- RewardBufferCollator,
-)
-from .loss import (
- CriticLoss,
- PPOPolicyLoss,
- compute_advantages_and_returns,
- compute_kl_rewards,
- gather_logprobs,
-)
-from .model import build_actor_model, build_reward_model
-
-__all__ = [
- "InferDataset",
- "RewardBuffer",
- "RewardBufferCollator",
- "PPOCollator",
- "PPODataset",
- "PPOTokenizeFunction",
- "CriticLoss",
- "PPOPolicyLoss",
- "compute_advantages_and_returns",
- "compute_kl_rewards",
- "compute_rewards",
- "gather_logprobs",
- "build_actor_model",
- "build_reward_model",
-]
diff --git a/code/xtuner/_lite/algorithms/ppo/dataset.py b/code/xtuner/_lite/algorithms/ppo/dataset.py
deleted file mode 100644
index d83bc694b8c985bddcac191dfaa288783a1a9988..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/ppo/dataset.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import json
-
-import numpy as np
-import torch
-from torch import nn
-
-from xtuner._lite.chat.messages.chat import ChatMsg
-from xtuner._lite.datasets import OPENAI_CONVERT_MAP
-
-from ..sft import SftCollator, SftTokenizeFunction
-
-
-class InferDataset(torch.utils.data.Dataset):
- def __init__(self, prompts, responses):
- super().__init__()
-
- assert len(prompts) == len(responses)
- self.prompts = prompts
- self.responses = responses
- self.policies = None
-
- def __len__(self):
- return len(self.prompts)
-
- def __getitem__(self, item):
- prompt = self.prompts[item]
- response = self.responses[item]
- num_prefill_tokens = len(prompt)
-
- input_ids = prompt + response
- labels = [-100] * (num_prefill_tokens - 1) + response + [-100]
-
- return {"input_ids": input_ids, "labels": labels, "num_tokens": len(input_ids)}
-
-
-FASTER = False
-
-
-class RewardBuffer(torch.utils.data.Dataset):
- def __init__(self, clip_min=-5, clip_max=5, normalize=True, faster=False):
- super().__init__()
-
- self.clip_min = clip_min
- self.clip_max = clip_max
-
- self.normalize = normalize
-
- if self.normalize:
- self.bn = nn.BatchNorm1d(1, momentum=None, affine=False)
- else:
- self.bn = None
-
- self._num_action_tokens = 0
- self._num_total_tokens = 0
- self._trajectories = []
-
- self._current_mean = 0
-
- @property
- def running_mean(self):
- return self.bn.running_mean.item()
-
- @property
- def current_mean(self):
- return self._current_mean
-
- @property
- def num_action_tokens(self):
- return self._num_action_tokens.item()
-
- @property
- def num_total_tokens(self):
- return self._num_total_tokens
-
- def update(self, trajectories):
- rewards = [data["reward"] for data in trajectories]
-
- for i in range(len(trajectories)):
- trajectories[i]["ori_reward"] = trajectories[i]["reward"]
-
- rewards = torch.tensor(rewards)
-
- self._current_mean = rewards.mean().item()
-
- rewards = rewards.clip(self.clip_min, self.clip_max)
-
- if self.normalize:
- self.bn.train()
- _ = self.bn(rewards.unsqueeze(-1))
- self.bn.eval()
- rewards = self.bn(rewards.unsqueeze(-1))
-
- for i in range(len(trajectories)):
- trajectories[i]["reward"] = rewards[i].item()
-
- num_total_tokens = 0
- num_action_tokens = 0
- for data in trajectories:
- labels = np.array(data["labels"])
- num_total_tokens += labels.size
- num_action_tokens += (labels >= 0).sum()
-
- self._num_action_tokens = num_action_tokens
- self._num_total_tokens = num_total_tokens
-
- self._trajectories = trajectories
-
- def dump_jsonl(self, path, tokenizer, debug=False):
- with open(path, "w", encoding="utf8") as f:
- for data in self._trajectories:
- json_line = {
- "num_tokens": data["num_tokens"],
- "reward": data["ori_reward"],
- "sequence": tokenizer.decode(data["input_ids"]),
- }
-
- if debug:
- json_line["input_ids"] = data["input_ids"]
- json_line["labels"] = data["labels"]
-
- json_str = json.dumps(json_line, ensure_ascii=False)
- f.write(json_str + "\n")
-
- def __len__(self):
- return len(self._trajectories)
-
- def __getitem__(self, item):
- return self._trajectories[item]
-
-
-class PPOTokenizeFunction(SftTokenizeFunction):
- def __init__(self, tokenizer, chat_template, raw_format="openai", sys_prompt=None):
- super().__init__(tokenizer, chat_template, raw_format)
- self.sys_prompt = sys_prompt
-
- def __call__(self, item):
- formatter = OPENAI_CONVERT_MAP[self.raw_format]
- msg = formatter(item)
- if self.sys_prompt is not None:
- sys_msg = ChatMsg(role="system", content=self.sys_prompt)
- msg.messages = [sys_msg] + msg.messages
- tokenized = msg.tokenize(self.tokenizer, self.chat_template)
-
- return tokenized
-
-
-class RewardBufferCollator(SftCollator):
- def __call__(self, instances):
- data = super().__call__(instances)
- data["rewards"] = [item["reward"] for item in instances]
-
- return data
diff --git a/code/xtuner/_lite/algorithms/ppo/loss.py b/code/xtuner/_lite/algorithms/ppo/loss.py
deleted file mode 100644
index 71f1f2b7b73664dc91f6f899a0418ea0581bb52b..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/ppo/loss.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from torch.nn import functional as F
-
-from xtuner._lite import get_logger
-
-logger = get_logger()
-
-
-def gather_logprobs(logits, labels):
- log_probs = F.log_softmax(logits, dim=-1)
- log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
- return log_probs_labels.squeeze(-1)
-
-
-@torch.no_grad()
-def compute_kl_rewards(logprobs, ref_logprobs, reward_score, kl_coef=0.01):
- assert logprobs.ndim == 1
- last_mask = torch.zeros_like(logprobs, dtype=torch.int)
- last_mask[-1] = 1
-
- kl = ref_logprobs - logprobs
- kl_reward = kl_coef * kl * (1 - last_mask)
-
- last_reward = reward_score * last_mask
-
- rewards = kl_reward + last_reward
-
- return rewards
-
-
-@torch.no_grad()
-def compute_advantages_and_returns(values, rewards, gamma=1.0, gae_lambda=0.99):
- # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501
- """Function that computes advantages and returns from rewards and values.
- Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
- Note that rewards may include a KL divergence loss term.
-
- Advantages looks like this:
- Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
- - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
-
- Returns looks like this:
- Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ...
- + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
- """
- lastgaelam = 0
- advantages_reversed = []
-
- assert values.numel() == rewards.numel(), f"{values.numel()}, {rewards.numel()}"
- length = rewards.numel()
-
- for t in reversed(range(0, length)):
- nextvalues = values[t + 1] if t < length - 1 else 0.0
- # Since old_rewards and old_values are masked with action_mask,
- # i.e. they have 0's at pad tokens,
- # delta will be 0 if current t is at a pad token,
- # so will lastgaelam
- delta = rewards[t] + gamma * nextvalues - values[t]
- lastgaelam = delta + gamma * gae_lambda * lastgaelam
- advantages_reversed.append(lastgaelam)
-
- advantages = torch.stack(advantages_reversed[::-1], dim=0)
- returns = advantages + values
- return advantages.detach(), returns
-
-
-class CriticLoss(torch.nn.Module):
- """Loss function for critic model."""
-
- def __init__(self, cliprange_value: float = 0.5, loss_type: str = "per_seq"):
- super().__init__()
- self.cliprange_value = cliprange_value
- self.loss_type = loss_type
-
- assert self.loss_type in ["per_token", "per_seq"]
-
- def critic_loss_fn(self, values, old_values, returns, loss_factor=None):
- values_clipped = old_values + (values - old_values).clamp(
- -self.cliprange_value, self.cliprange_value
- )
- vf_loss1 = (values_clipped - returns) ** 2
- vf_loss2 = (values - returns) ** 2
- if self.loss_type == "per_seq":
- vf_loss = torch.max(vf_loss1, vf_loss2).mean(-1)
- elif self.loss_type == "per_token":
- assert loss_factor is not None
- vf_loss = torch.sum(torch.max(vf_loss1, vf_loss2) * loss_factor)
- return 0.5 * vf_loss
-
- def forward(self, values: torch.Tensor, old_values, returns, loss_factor=None):
- loss = self.critic_loss_fn(
- values=values,
- old_values=old_values,
- returns=returns,
- loss_factor=loss_factor,
- )
- return loss
-
-
-class PPOPolicyLoss(torch.nn.Module):
- """Loss function for policy model."""
-
- def __init__(self, cliprange: float = 0.2, loss_type: str = "per_seq"):
- super().__init__()
- self.cliprange = cliprange
- self.loss_type = loss_type
- assert self.loss_type in ["per_token", "per_seq"]
-
- def forward(self, logprobs, old_logprobs, advantages, loss_factor=None):
- ratio = (logprobs - old_logprobs).exp()
- pg_loss1 = -ratio * advantages
- pg_loss2 = -ratio.clamp(1 - self.cliprange, 1 + self.cliprange) * advantages
- if self.loss_type == "per_seq":
- pg_loss = torch.max(pg_loss1, pg_loss2).mean(dim=-1)
- elif self.loss_type == "per_token":
- assert loss_factor is not None
- pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2)) * loss_factor
- return pg_loss
diff --git a/code/xtuner/_lite/algorithms/ppo/model.py b/code/xtuner/_lite/algorithms/ppo/model.py
deleted file mode 100644
index fefc4d2ed04485498ebc6737eeb7829179a3745f..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/ppo/model.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
-from transformers.utils.import_utils import (
- is_flash_attn_2_available,
- is_torch_sdpa_available,
-)
-
-from xtuner._lite.accelerate import LoadWoInit
-
-
-def build_actor_model(model_path, dtype=torch.float32, trust_remote_code=True):
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
- if is_flash_attn_2_available():
- config.attn_implementation = "flash_attention_2"
- elif is_torch_sdpa_available():
- config.attn_implementation = "sdpa"
-
- with LoadWoInit():
- policy = AutoModelForCausalLM.from_pretrained(
- model_path,
- attn_implementation="flash_attention_2",
- torch_dtype=dtype,
- trust_remote_code=trust_remote_code,
- )
-
- return policy
-
-
-def build_reward_model(model_path, dtype=torch.float32, trust_remote_code=True):
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
- if is_flash_attn_2_available():
- config.attn_implementation = "flash_attention_2"
- elif is_torch_sdpa_available():
- config.attn_implementation = "sdpa"
-
- config.use_cache = False
- config.torch_dtype = dtype
- with LoadWoInit():
- reward = AutoModel.from_pretrained(
- model_path,
- attn_implementation="flash_attention_2",
- torch_dtype=dtype,
- trust_remote_code=trust_remote_code,
- )
-
- reward.model.use_cache = False
-
- return reward
diff --git a/code/xtuner/_lite/algorithms/sft/__init__.py b/code/xtuner/_lite/algorithms/sft/__init__.py
deleted file mode 100644
index 45fe845c5f456a15e3a67c0157b4912dc02ecf1a..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/sft/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .dataset import SftCollator, SftTokenizeFunction
-
-__all__ = ["SftCollator", "SftTokenizeFunction"]
diff --git a/code/xtuner/_lite/algorithms/sft/dataset.py b/code/xtuner/_lite/algorithms/sft/dataset.py
deleted file mode 100644
index 8986c89efefb169489559da6f8a2492eb2a0bc82..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/algorithms/sft/dataset.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from xtuner._lite import get_logger
-from xtuner._lite.datasets import OPENAI_CONVERT_MAP
-
-logger = get_logger()
-
-
-class SftTokenizeFunction:
- def __init__(self, tokenizer, chat_template, raw_format="openai"):
- self.tokenizer = tokenizer
- self.chat_template = chat_template
- self.raw_format = raw_format
-
- def __call__(self, item):
- formatter = OPENAI_CONVERT_MAP[self.raw_format]
- msg = formatter(item)
- tokenized = msg.tokenize(self.tokenizer, self.chat_template)
- return tokenized
-
-
-class SftCollator:
- def __init__(
- self, pad_token_id=0, ignore_id=-100, pack_batch=False, max_length=None
- ):
- self.pack_batch = pack_batch
- self.pad_token_id = pad_token_id
- self.ignore_id = ignore_id
- self.max_length = max_length
-
- def __call__(self, instances):
- _instances = []
- for ins in instances:
- if isinstance(ins, list):
- _instances.extend(ins)
- else:
- _instances.append(ins)
-
- instances = _instances
-
- input_ids = []
- labels = []
- num_tokens = []
-
- for data in instances:
- _input_ids = data["input_ids"]
- _labels = data["labels"]
- _num_tokens = data["num_tokens"]
-
- # TODO remove list
- if isinstance(_num_tokens, list):
- assert len(_num_tokens) == 1
- _num_tokens = _num_tokens[0]
-
- assert isinstance(_num_tokens, int)
-
- if self.max_length:
- _input_ids = _input_ids[: self.max_length]
- _labels = _labels[: self.max_length]
- _num_tokens = min(_num_tokens, self.max_length)
-
- input_ids.append(torch.LongTensor(_input_ids))
- labels.append(torch.LongTensor(_labels))
- num_tokens.append(_num_tokens)
-
- attention_mask = [torch.ones_like(ids) for ids in input_ids]
- num_tokens = torch.IntTensor(num_tokens)
-
- if len(instances) > 1 and self.pack_batch:
- input_ids = torch.cat(input_ids, dim=0).unsqueeze(0)
- labels = torch.cat(labels, dim=0).unsqueeze(0)
- attention_mask = torch.cat(attention_mask, dim=0).unsqueeze(0)
-
- elif len(instances) > 1 and not self.pack_batch:
- input_ids = pad_sequence(
- input_ids, batch_first=True, padding_value=self.pad_token_id
- )
- labels = pad_sequence(
- labels, batch_first=True, padding_value=self.ignore_id
- )
- attention_mask = pad_sequence(
- attention_mask, batch_first=True, padding_value=0
- )
- else:
- input_ids = torch.stack(input_ids)
- labels = torch.stack(labels)
- attention_mask = torch.stack(attention_mask)
-
- if input_ids.shape != labels.shape:
- logger.error(f"[instances] {instances}")
- logger.error(f"[num_tokens] {num_tokens}")
- logger.error(f"[input_ids] {input_ids}")
- logger.error(f"[labels] {labels}")
- raise RuntimeError(
- "The shape of input_ids and labels must be "
- f"equal, but found {input_ids.shape} and "
- f"{labels.shape}."
- )
-
- data_dict = {
- "input_ids": input_ids,
- "labels": labels,
- "num_tokens": num_tokens,
- "attention_mask": attention_mask.bool(),
- }
-
- return data_dict
diff --git a/code/xtuner/_lite/chat/.DS_Store b/code/xtuner/_lite/chat/.DS_Store
deleted file mode 100644
index 2859483dee4d7c8b06db32dbbec56a79888c9c1b..0000000000000000000000000000000000000000
Binary files a/code/xtuner/_lite/chat/.DS_Store and /dev/null differ
diff --git a/code/xtuner/_lite/chat/__init__.py b/code/xtuner/_lite/chat/__init__.py
deleted file mode 100644
index 1eb033d0edcf63e0b85061371c919b27576fe95a..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .messages import ChatMessages
-from .templates import CHAT_TEMPLATE_MAP, ChatTemplate, HybridChatTemplate
-
-__all__ = ["ChatMessages", "CHAT_TEMPLATE_MAP", "ChatTemplate", "HybridChatTemplate"]
diff --git a/code/xtuner/_lite/chat/backends/__init__.py b/code/xtuner/_lite/chat/backends/__init__.py
deleted file mode 100644
index ef101fec61e72abc0eb90266d453b5b22331378d..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/backends/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/code/xtuner/_lite/chat/messages/__init__.py b/code/xtuner/_lite/chat/messages/__init__.py
deleted file mode 100644
index bc3ced2dc22bf3fa00fdbefea1a4f7598d5786f8..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/messages/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .base import BaseMessages
-from .chat import ChatMessages
-
-__all__ = ["BaseMessages", "ChatMessages"]
diff --git a/code/xtuner/_lite/chat/messages/base.py b/code/xtuner/_lite/chat/messages/base.py
deleted file mode 100644
index dc6b0615e815477a3a80db7cb766b7f6430f8c4c..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/messages/base.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from abc import abstractclassmethod, abstractmethod
-from typing import Dict
-
-from pydantic import BaseModel
-from transformers import PreTrainedTokenizer
-
-from ..templates import ChatTemplate
-
-
-class BaseMessages(BaseModel):
- @abstractmethod
- def add(self, role: str, content):
- pass
-
- @abstractmethod
- def pop(self):
- pass
-
- @abstractmethod
- def get_prompt(self, chat_template: ChatTemplate) -> str:
- pass
-
- @abstractmethod
- def tokenize(
- self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate
- ) -> Dict:
- pass
-
- @abstractclassmethod
- def from_dict(cls, item: Dict) -> "BaseMessages":
- pass
diff --git a/code/xtuner/_lite/chat/messages/chat.py b/code/xtuner/_lite/chat/messages/chat.py
deleted file mode 100644
index 4af88c3e4491688836cad46b1152158da2f22323..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/messages/chat.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-from typing import Dict, List, Literal, Optional, Union
-
-from pydantic import BaseModel
-from transformers import PreTrainedTokenizer
-
-from xtuner._lite import get_logger
-from xtuner.utils import IGNORE_INDEX
-
-from ..templates import ChatTemplate, HybridChatTemplate
-from .base import BaseMessages
-
-logger = get_logger()
-
-
-class TextContentItem(BaseModel):
- type: Literal["text"] = "text"
- text: str
-
- def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
- return self.text
-
-
-class ImageContentItem(BaseModel):
- type: Literal["image_url"] = "image_url"
- image_url: str
-
- def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
- return chat_template.image_token
-
-
-MultModalContentType = Union[TextContentItem, ImageContentItem]
-ContentType = Union[str, List[MultModalContentType]]
-
-
-class ChatMsg(BaseModel):
- role: Literal["assistant", "user", "system"]
- content: ContentType
- loss: Optional[bool] = None
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- if self.loss is None:
- if self.role == "system":
- self.loss = False
- elif self.role == "user":
- self.loss = False
- elif self.role == "assistant":
- self.loss = True
- else:
- raise NotImplementedError
-
- def collect_img_urls(self) -> List[str]:
- img_urls = []
- if isinstance(self.content, list):
- for item in self.content:
- if isinstance(item, ImageContentItem):
- img_urls.append(item.image_url)
- return img_urls
-
- def get_prompt(self, chat_template: ChatTemplate) -> str:
- if isinstance(self.content, str):
- text = self.content
- elif isinstance(self.content, list):
- text = ""
- for i, item in enumerate(self.content):
- if i == 0:
- text += item.apply_chat_template(chat_template)
- else:
- text += "\n" + item.apply_chat_template(chat_template)
- else:
- raise NotImplementedError
-
- if self.role == "system":
- prompt = chat_template.decorate_system(text)
- elif self.role == "user":
- prompt = chat_template.decorate_user(text)
- elif self.role == "assistant":
- prompt = chat_template.decorate_assistant(text)
- else:
- raise NotImplementedError
-
- return prompt
-
- def tokenize(
- self,
- tokenizer: PreTrainedTokenizer,
- chat_template: ChatTemplate,
- ):
- decorated = self.get_prompt(chat_template)
-
- token_ids = tokenizer.encode(decorated, add_special_tokens=False)
-
- if self.loss:
- label_ids = copy.deepcopy(token_ids)
- else:
- label_ids = [IGNORE_INDEX] * len(token_ids)
-
- return {
- "input_ids": token_ids,
- "labels": label_ids,
- }
-
-
-class ChatMessages(BaseMessages):
- messages: List[ChatMsg]
-
- def add(self, role, content, loss=False):
- self.messages.append(ChatMsg(role=role, content=content, loss=loss))
-
- def pop(self):
- return self.messages.pop()
-
- def get_prompt(self, chat_template: ChatTemplate) -> str:
- prompt = ""
-
- for msg in self.messages:
- prompt += msg.get_prompt(chat_template)
- if msg.role == "assistant":
- prompt += chat_template.sep
- return prompt
-
- def tokenize(
- self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate
- ) -> Dict:
- input_ids = tokenizer.encode("", add_special_tokens=True)
- labels = [IGNORE_INDEX for _ in input_ids]
- image_urls = []
-
- for msg in self.messages:
- res = msg.tokenize(tokenizer, chat_template)
- token_ids, label_ids = res["input_ids"], res["labels"]
-
- input_ids.extend(token_ids)
- labels.extend(label_ids)
-
- image_urls.extend(msg.collect_img_urls())
-
- if msg.role == "assistant":
- sep = chat_template.sep
- sep_tokens = tokenizer.encode(sep, add_special_tokens=False)
- input_ids.extend(sep_tokens)
- labels.extend([IGNORE_INDEX] * len(sep_tokens))
-
- if len(input_ids) != len(labels):
- logger.error(f"[messages] {self.messages}")
- logger.error(f"[input_ids] {input_ids}")
- logger.error(f"[labels] {labels}")
- raise RuntimeError(
- "The lengths of input_ids and labels must be "
- f"equal, but found {len(input_ids)} and "
- f"{len(labels)}."
- )
-
- training_data = {
- "input_ids": input_ids,
- "labels": labels,
- "num_tokens": len(input_ids),
- }
-
- if len(image_urls) > 0:
- training_data["image_urls"] = image_urls
-
- return training_data
-
- @classmethod
- def from_str(cls, prompt: str) -> "ChatMessages":
- msg = ChatMsg(role="user", content=prompt)
- return cls(messages=[msg])
-
- @classmethod
- def from_dict(cls, item: dict) -> "ChatMessages":
- """
- item
- {
- 'messages':[
- {'role':'user', 'content':'hello'},
- {'role':'assistant', 'content':'hello!'},
- ],
- }
- """
- return cls(**item)
-
-
-if __name__ == "__main__":
- data = {
- "messages": [
- {"role": "user", "content": "hello"},
- {"role": "assistant", "content": "hello!"},
- ]
- }
-
- messages = ChatMessages.from_dict(data)
- chat_template = ChatTemplate(
- system="<|im_start|>system\n{system}<|im_end|>\n",
- user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
- assistant="{assistant}<|im_end|>\n",
- stop_words=["<|im_end|>"],
- )
-
- print(messages.get_prompt(chat_template))
diff --git a/code/xtuner/_lite/chat/templates/__init__.py b/code/xtuner/_lite/chat/templates/__init__.py
deleted file mode 100644
index 4d82700a0985e276b1c39007937674620df45ee4..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/templates/__init__.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .chat import ChatTemplate
-from .hybrid import HybridChatTemplate
-
-CHAT_TEMPLATE_MAP = {
- "internlm2": HybridChatTemplate(
- system="<|im_start|>system\n{system}<|im_end|>\n",
- user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
- assistant="{assistant}<|im_end|>",
- stop_words=["<|im_end|>"],
- ),
- "qwen2": HybridChatTemplate(
- system="<|im_start|>system\n{system}<|im_end|>\n",
- user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
- assistant="{assistant}<|im_end|>",
- stop_words=["<|im_end|>", "<|endoftext|>"],
- ),
- "llama3": HybridChatTemplate(
- system=("<|start_header_id|>system<|end_header_id|>\n\n{system}" "<|eot_id|>"),
- user=(
- "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
- "<|start_header_id|>assistant<|end_header_id|>\n\n"
- ),
- assistant="{assistant}<|eot_id|>",
- sep="",
- stop_words=["<|eot_id|>"],
- ),
-}
-
-__all__ = ["ChatTemplate", "HybridChatTemplate"]
diff --git a/code/xtuner/_lite/chat/templates/chat.py b/code/xtuner/_lite/chat/templates/chat.py
deleted file mode 100644
index 6b29e4108a8936c8f64104b8ab52f39b0cb3f7a2..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/templates/chat.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
-
-from pydantic import BaseModel, field_validator
-
-
-class ChatTemplate(BaseModel):
- """Define a Pydantic data model for a hybrid chat with attributes for
- system, user and assistant chat as well as function and interpreter calls
- and results."""
-
- # Normal Chat
- system: str # System message format
- user: str # User message format
- assistant: str # Assistant message format
- stop_words: List[str] # List of stop words
- sep: str = "\n"
-
- def decorate_system(self, text: str) -> str:
- """Decorate text with the `system` template."""
- return self.system.format(system=text)
-
- def decorate_assistant(self, text: str) -> str:
- """Decorate text with the `assistant` template."""
- return self.assistant.format(assistant=text)
-
- def decorate_user(self, text: str) -> str:
- """Decorate text with the `user` template."""
- return self.user.format(user=text)
-
- @field_validator("system")
- def check_system(cls, v: str) -> str:
- """Validate that `system` contains '{system}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{system}" not in v:
- raise ValueError("system must contain the keyword '{system}'")
- return v
-
- @field_validator("user")
- def check_user(cls, v: str) -> str:
- """Validate that `user` contains '{user}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{user}" not in v:
- raise ValueError("user must contain the keyword '{user}'")
- return v
-
- @field_validator("assistant")
- def check_assistant(cls, v: str) -> str:
- """Validate that `assistant` contains '{assistant}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{assistant}" not in v:
- raise ValueError("assistant must contain the keyword '{assistant}'")
- return v
diff --git a/code/xtuner/_lite/chat/templates/hybrid.py b/code/xtuner/_lite/chat/templates/hybrid.py
deleted file mode 100644
index d4b7cbfddfd04ab1033c6b917223df5fc5165bc6..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/chat/templates/hybrid.py
+++ /dev/null
@@ -1,206 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional
-
-from pydantic import BaseModel, field_validator
-
-
-class HybridChatTemplate(BaseModel):
- """Define a Pydantic data model for a hybrid chat with attributes for
- system, user and assistant chat as well as function and interpreter calls
- and results."""
-
- # Normal Chat
- system: str # System message format
- user: str # User message format
- assistant: str # Assistant message format
- stop_words: List[str] # List of stop words
- sep: str = "\n"
-
- # Multimodal Chat
- # Predefined token and index for images
- image_token: str = ""
- image_token_index: int = -100
-
- # Agent Chat
-
- # Interpreter and function related strings
- files: Optional[str] = None
-
- functions: Optional[str] = None # Function description format
- function_call: Optional[str] = None # Function call format
- function_result: Optional[str] = None # Function result format
-
- code_interpreter: Optional[str] = None
- code_interpreter_call: Optional[str] = None # Interpreter call format
- code_interpreter_result: Optional[str] = None # Interpreter result format
-
- function_token: Optional[str] = None
- code_interpreter_token: Optional[str] = None
- action_start_token: Optional[str] = None
- action_end_token: Optional[str] = None
-
- @property
- def mm_token_maps(self) -> Dict[str, int]:
- """Return a dictionary that maps multimodal tokens to corresponding
- token indexes."""
- return {self.image_token: self.image_token_index}
-
- def decorate_system(self, text: str) -> str:
- """Decorate text with the `system` template."""
- return self.system.format(system=text)
-
- def decorate_assistant(self, text: str) -> str:
- """Decorate text with the `assistant` template."""
- return self.assistant.format(assistant=text)
-
- def decorate_user(self, text: str) -> str:
- """Decorate text with the `user` template."""
- return self.user.format(user=text)
-
- def decorate_files(self, text: str) -> str:
- """Decorate text with the `functions` template."""
- return self.files.format(files=text)
-
- def decorate_functions(self, text: str) -> str:
- """Decorate text with the `functions` template."""
- return self.functions.format(functions=text)
-
- def decorate_function_call(self, text: str, func: str) -> str:
- """Decorate text with the `function_call` template."""
- return self.function_call.format(assistant=text, function_call=func)
-
- def decorate_function_result(self, text: str) -> str:
- """Decorate text with the `function_result` template."""
- return self.function_result.format(function_result=text)
-
- def decorate_code_interpreter(self, text: str) -> str:
- """Decorate text with the `code_interpreter` template."""
- return self.code_interpreter.format(code_interpreter=text)
-
- def decorate_code_interpreter_call(self, text: str, func: str) -> str:
- """Decorate text with the `code_interpreter_call` template."""
- return self.code_interpreter_call.format(
- assistant=text, code_interpreter_call=func
- )
-
- def decorate_code_interpreter_result(self, text: str) -> str:
- """Decorate text with the `code_interpreter_result` template."""
- return self.code_interpreter_result.format(code_interpreter_result=text)
-
- @field_validator("system")
- def check_system(cls, v: str) -> str:
- """Validate that `system` contains '{system}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{system}" not in v:
- raise ValueError("system must contain the keyword '{system}'")
- return v
-
- @field_validator("user")
- def check_user(cls, v: str) -> str:
- """Validate that `user` contains '{user}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{user}" not in v:
- raise ValueError("user must contain the keyword '{user}'")
- return v
-
- @field_validator("assistant")
- def check_assistant(cls, v: str) -> str:
- """Validate that `assistant` contains '{assistant}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{assistant}" not in v:
- raise ValueError("assistant must contain the keyword '{assistant}'")
- return v
-
- @field_validator("function_call")
- def check_function_call(cls, v: str) -> str:
- """Validate that `function_call` contains '{function_call}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{function_call}" not in v and "{assistant}" not in v:
- raise ValueError(
- "function_call must contain the keywords '{function_call}'"
- )
- if v is not None and "{assistant}" not in v:
- raise ValueError(
- "function_call must contain the keyword '{assistant}' and "
- "'{function_call}'"
- )
- return v
-
- @field_validator("function_result")
- def check_function_result(cls, v: str) -> str:
- """Validate that `function_result` contains '{function_result}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{function_result}" not in v:
- raise ValueError(
- "function_result must contain the keyword '{function_result}'"
- )
- return v
-
- @field_validator("functions")
- def check_functions(cls, v: str) -> str:
- """Validate that `functions` contains '{functions}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{functions}" not in v:
- raise ValueError("functions must contain the keyword '{functions}'")
- return v
-
- @field_validator("code_interpreter")
- def check_code_interpreter(cls, v: str) -> str:
- """Validate that `code_interpreter` contains '{code_interpreter}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{code_interpreter}" not in v:
- raise ValueError(
- "code_interpreter must contain the keyword " "'{code_interpreter}'"
- )
- return v
-
- @field_validator("code_interpreter_call")
- def check_code_interpreter_call(cls, v: str) -> str:
- """Validate that `code_interpreter_call` contains
- '{code_interpreter_call}'.
-
- If not, raises a ValueError.
- """
- if (
- v is not None
- and "{code_interpreter_call}" not in v
- and "{assistant}" not in v
- ):
- raise ValueError(
- "code_interpreter_call must contain the keywords "
- "'{assistant}' and '{code_interpreter_call}'"
- )
- if v is not None and "{assistant}" not in v:
- raise ValueError(
- "code_interpreter_call must contain the keywords "
- "'{assistant}' and '{code_interpreter_call}'"
- )
- return v
-
- @field_validator("code_interpreter_result")
- def check_code_interpreter_result(cls, v: str) -> str:
- """Validate that `code_interpreter_result` contains
- '{code_interpreter_result}'.
-
- If not, raises a ValueError.
- """
- if v is not None and "{code_interpreter_result}" not in v:
- raise ValueError(
- "code_interpreter_result must contain the keyword "
- "'{code_interpreter_result}'"
- )
- return v
diff --git a/code/xtuner/_lite/datasets/__init__.py b/code/xtuner/_lite/datasets/__init__.py
deleted file mode 100644
index b5b4b24bf8c9bec144dc7f92cc5dce9169ecdf46..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .json import JsonDataset
-from .jsonl import JsonlDataset
-from .pack import SoftPackDataset
-from .utils import DATASET_CLS_MAP, OPENAI_CONVERT_MAP, load_datasets
-
-__all__ = [
- "JsonDataset",
- "JsonlDataset",
- "SoftPackDataset",
- "DATASET_CLS_MAP",
- "OPENAI_CONVERT_MAP",
- "load_datasets",
-]
diff --git a/code/xtuner/_lite/datasets/json.py b/code/xtuner/_lite/datasets/json.py
deleted file mode 100644
index 08c3cb50b9ddf991f283ce30610ac2c4f0caa10f..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/json.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import hashlib
-import inspect
-import json
-import math
-import os
-import random
-from concurrent.futures import ProcessPoolExecutor
-
-import numpy as np
-import torch
-from mmengine import mkdir_or_exist
-from torch import distributed as dist
-from tqdm import tqdm
-
-from xtuner._lite import get_logger
-
-logger = get_logger()
-
-
-def calculate_json_sha256(file_path):
- with open(file_path, "rb") as f:
- data = f.read()
-
- hash_object = hashlib.sha256(data)
- hash_hex = hash_object.hexdigest()
- return hash_hex
-
-
-def calculate_tokenize_fn_sha256(tokenize_fn):
- """Calculate SHA-256 hash for an instance method's source code."""
- # Get the source code of the method
- fn_source = inspect.getsource(tokenize_fn.__call__)
- return hashlib.sha256(fn_source.encode("utf-8")).hexdigest()
-
-
-class JsonDataset(torch.utils.data.Dataset):
- def __init__(
- self, path, sample_ratio=1.0, tokenize_fn=None, cache_dir=None, max_length=None
- ):
- super().__init__()
-
- self.tokenize_fn = tokenize_fn
- self.path = path
- self.tokenizer_workers = int(os.environ.get("XTUNER_TOKENIZE_WORKERS", 8))
-
- if cache_dir:
- if os.path.exists(cache_dir):
- assert os.path.isdir(cache_dir)
- else:
- mkdir_or_exist(cache_dir)
-
- file_hash = calculate_json_sha256(path)
- file_cache_dir = os.path.join(cache_dir, file_hash)
-
- if file_hash not in os.listdir(cache_dir):
- mkdir_or_exist(file_cache_dir)
-
- if self.tokenize_fn:
- tok_hash = calculate_tokenize_fn_sha256(tokenize_fn)
- tok_cache_dir = os.path.join(file_cache_dir, tok_hash)
- if tok_hash not in os.listdir(file_cache_dir):
- mkdir_or_exist(tok_cache_dir)
-
- if "num_tokens.npy" in os.listdir(tok_cache_dir):
- _cached_file = os.path.join(tok_cache_dir, "num_tokens.npy")
- num_tokens = np.load(_cached_file)
- else:
- num_tokens = self.count_tokens(tok_cache_dir)
- else:
- num_tokens = None
-
- else:
- num_tokens = None
-
- with open(self.path) as f:
- dataset = json.load(f)
-
- _sampled = [i for i in range(len(dataset))]
-
- if max_length is not None:
- assert isinstance(max_length, int)
- _filtered = [
- x for i, x in enumerate(_sampled) if num_tokens[i] < max_length
- ]
-
- if len(_filtered) < len(_sampled):
- missed_num = len(_sampled) - len(_filtered)
- logger.warning(
- f"{path} has {missed_num} prompt length>{max_length}, discard."
- )
-
- _sampled = _filtered
-
- _target_num_samples = int(len(_sampled) * sample_ratio)
- self.sampled = _sampled * int(sample_ratio)
- self.sampled.extend(
- random.sample(_sampled, _target_num_samples - len(self.sampled))
- )
-
- if num_tokens is not None:
- num_tokens = num_tokens[self.sampled]
-
- self.num_tokens = num_tokens
- self.dataset = None
-
- def count_tokens(self, cache_dir=None):
- dataset = []
-
- with open(self.path) as f:
- dataset = json.load(f)
-
- num_samples = len(dataset)
-
- if dist.is_available():
- world_size = dist.get_world_size()
- rank = dist.get_rank()
- else:
- world_size = 1
- rank = 0
-
- num_per_rank = math.ceil(num_samples / world_size)
-
- start = rank * num_per_rank
- end = (rank + 1) * num_per_rank
- dataset_shard = dataset[start:end]
-
- desc = f"[Rank {rank}] {self.path}"
- chunk_size = min(1024, max(1, len(dataset_shard) // self.tokenizer_workers))
- with ProcessPoolExecutor(max_workers=self.tokenizer_workers) as executor:
- tokenized = list(
- tqdm(
- executor.map(self.tokenize_fn, dataset_shard, chunksize=chunk_size),
- desc=desc,
- total=len(dataset_shard),
- )
- )
-
- _num_tokens = [data["num_tokens"] for data in tokenized]
- _num_tokens = np.array(_num_tokens)
-
- if dist.is_available():
- num_tokens = [None] * world_size
- dist.all_gather_object(num_tokens, _num_tokens)
- num_tokens = np.concatenate(num_tokens, axis=0)
- else:
- num_tokens = _num_tokens
-
- if rank == 0 and cache_dir:
- save_path = os.path.join(cache_dir, "num_tokens.npy")
- np.save(save_path, num_tokens)
-
- return num_tokens
-
- def __len__(self):
- return len(self.sampled)
-
- def __getitem__(self, item):
- """Returns a dict containing packed data in the given item.
-
- Args:
- item: An index to retrieve packed data.
-
- Returns:
- A dict including packed input_ids, labels, and cumulative_len.
- """
- if self.dataset is None:
- with open(self.path) as f:
- self.dataset = json.load(f)
-
- raw_data = self.dataset[self.sampled[item]]
-
- if self.tokenize_fn:
- tokenized_data = self.tokenize_fn(raw_data)
- return tokenized_data
- else:
- return raw_data
diff --git a/code/xtuner/_lite/datasets/jsonl.py b/code/xtuner/_lite/datasets/jsonl.py
deleted file mode 100644
index 75cdf8843f4ed780eeacb41eed3e6d5a42373249..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/jsonl.py
+++ /dev/null
@@ -1,220 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import hashlib
-import json
-import math
-import multiprocessing
-import os
-import random
-from abc import ABC, abstractmethod
-from concurrent.futures import ProcessPoolExecutor
-from typing import Any, Callable, TypedDict
-
-import numpy as np
-import torch
-from mmengine import mkdir_or_exist
-from torch import distributed as dist
-from tqdm import tqdm
-
-from xtuner._lite import get_logger
-
-logger = get_logger()
-
-
-def calculate_jsonl_sha256(path):
- with open(path, "rb") as f:
- file_hash = hashlib.sha256()
- file_hash.update(f.read())
- return file_hash.hexdigest()
-
-
-CacheObj = TypedDict("CachedObj", {"num_tokens": int}, total=False)
-
-
-class CachableTokenizeFunction(ABC):
- @abstractmethod
- def __call__(self, item: Any) -> CacheObj:
- raise NotImplementedError
-
- @abstractmethod
- def hash(self) -> str:
- raise NotImplementedError
-
-
-class JsonlDataset(torch.utils.data.Dataset):
- def __init__(
- self,
- path,
- sample_ratio: float = 1.0,
- tokenize_fn: Callable[[Any], CacheObj] | None = None,
- cache_dir: str | None = None,
- max_length: int | None = None,
- ):
- super().__init__()
-
- self.tokenize_fn = tokenize_fn
- self.path = path
- self.tokenizer_workers = int(os.environ.get("XTUNER_TOKENIZE_WORKERS", 8))
-
- if cache_dir and isinstance(tokenize_fn, CachableTokenizeFunction):
- if os.path.exists(cache_dir):
- assert os.path.isdir(cache_dir)
- else:
- mkdir_or_exist(cache_dir)
-
- file_hash = calculate_jsonl_sha256(path)
- file_cache_dir = os.path.join(cache_dir, file_hash)
-
- if file_hash not in os.listdir(cache_dir):
- mkdir_or_exist(file_cache_dir)
-
- if "offsets.npy" in os.listdir(file_cache_dir):
- _cached_file = os.path.join(file_cache_dir, "offsets.npy")
- offsets = np.load(_cached_file)
- else:
- offsets = self.count_offsets(file_cache_dir)
-
- if self.tokenize_fn:
- tok_hash = tokenize_fn.hash()
- tok_cache_dir = os.path.join(file_cache_dir, tok_hash)
- if tok_hash not in os.listdir(file_cache_dir):
- mkdir_or_exist(tok_cache_dir)
-
- if "num_tokens.npy" in os.listdir(tok_cache_dir):
- _cached_file = os.path.join(tok_cache_dir, "num_tokens.npy")
- num_tokens = np.load(_cached_file)
- else:
- num_tokens = self.count_tokens(offsets, tok_cache_dir)
- else:
- num_tokens = None
-
- offsets = offsets
- num_tokens = num_tokens
-
- else:
- offsets = self.count_offsets()
- num_tokens = None
- if max_length is not None:
- assert self.tokenize_fn
- num_tokens = self.count_tokens(offsets)
-
- _sampled = [i for i in range(len(offsets))]
-
- if max_length is not None:
- assert isinstance(max_length, int)
- _filtered = [
- x for i, x in enumerate(_sampled) if num_tokens[i] < max_length
- ]
-
- if len(_filtered) < len(_sampled):
- missed_num = len(_sampled) - len(_filtered)
- logger.warning(
- f"{path} has {missed_num} prompt length>{max_length}, discard."
- )
-
- _sampled = _filtered
-
- _target_num_samples = int(len(_sampled) * sample_ratio)
- self.sampled = _sampled * int(sample_ratio)
- self.sampled.extend(
- random.sample(_sampled, _target_num_samples - len(self.sampled))
- )
-
- if num_tokens is not None:
- num_tokens = num_tokens[self.sampled]
-
- self.num_tokens = num_tokens
- self.offsets = offsets[self.sampled]
-
- def count_offsets(self, cache_dir=None):
- offsets = [0]
- with open(self.path) as f:
- lines = f.readlines()
- for line in lines[:-1]:
- offsets.append(offsets[-1] + len(line.encode()))
-
- offsets = np.array(offsets)
-
- if dist.get_rank() == 0 and cache_dir:
- save_path = os.path.join(cache_dir, "offsets.npy")
- np.save(save_path, offsets)
-
- return offsets
-
- def _tokenize_by_offset(self, offset):
- with open(self.path) as f:
- f.seek(offset)
- data = json.loads(f.readline())
- return self.tokenize_fn(data)
-
- def count_tokens(self, offsets, cache_dir=None):
- num_samples = len(offsets)
-
- if dist.is_available():
- world_size = dist.get_world_size()
- rank = dist.get_rank()
- else:
- world_size = 1
- rank = 0
-
- num_per_rank = math.ceil(num_samples / world_size)
-
- start = rank * num_per_rank
- end = (rank + 1) * num_per_rank
- offsets_shard = offsets[start:end]
-
- desc = f"[Rank {rank}] {self.path}"
- chunk_size = min(1024, max(1, len(offsets_shard) // self.tokenizer_workers))
-
- mp_context = multiprocessing.get_context("fork")
- with ProcessPoolExecutor(
- max_workers=self.tokenizer_workers, mp_context=mp_context
- ) as executor:
- tokenized = list(
- tqdm(
- executor.map(
- self._tokenize_by_offset, offsets_shard, chunksize=chunk_size
- ),
- desc=desc,
- total=len(offsets_shard),
- )
- )
-
- _num_tokens = [data["num_tokens"] for data in tokenized]
- _num_tokens = np.array(_num_tokens)
-
- if dist.is_available():
- num_tokens = [None] * world_size
- dist.all_gather_object(num_tokens, _num_tokens)
- num_tokens = np.concatenate(num_tokens, axis=0)
- else:
- num_tokens = _num_tokens
-
- if rank == 0 and cache_dir:
- save_path = os.path.join(cache_dir, "num_tokens.npy")
- np.save(save_path, num_tokens)
-
- return num_tokens
-
- def __len__(self):
- return len(self.offsets)
-
- def __getitem__(self, item):
- """Returns a dict containing packed data in the given item.
-
- Args:
- item: An index to retrieve packed data.
-
- Returns:
- A dict including packed input_ids, labels, and cumulative_len.
- """
- with open(self.path) as f:
- f.seek(self.offsets[item])
- line = f.readline()
-
- raw_data = json.loads(line)
-
- if self.tokenize_fn:
- tokenized_data = self.tokenize_fn(raw_data)
- return tokenized_data
- else:
- return raw_data
diff --git a/code/xtuner/_lite/datasets/pack.py b/code/xtuner/_lite/datasets/pack.py
deleted file mode 100644
index 2e294d4cb14d4df210cb1cee17c76d17a1204d2e..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/pack.py
+++ /dev/null
@@ -1,257 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import bisect
-import itertools
-import random
-
-import numpy as np
-import torch
-from datasets import Dataset, concatenate_datasets
-from torch.utils.data import ConcatDataset
-
-
-class SoftPackDataset(torch.utils.data.Dataset):
- def __init__(self, datasets, target=2048, blend=False, sort=False):
- if blend:
- num_tokens = [np.concatenate([dset.num_tokens for dset in datasets])]
- datasets = [ConcatDataset(datasets)]
- else:
- num_tokens = [dset.num_tokens for dset in datasets]
- self.datasets = datasets
- self.target = target
-
- pack_infos = []
- for i, dataset in enumerate(self.datasets):
- _infos = self.get_pack_infos(dataset, i, num_tokens[i])
- pack_infos.append(_infos)
- self.pack_infos = concatenate_datasets(pack_infos)
-
- @property
- def longest(self):
- return self.pack_infos["longest"]
-
- def get_pack_infos(self, dataset, dataset_id, num_tokens):
- # _ori_lens = dataset['num_tokens']
- inds = [i for i in range(len(dataset))]
- random.shuffle(inds)
-
- item_buffer = []
- length_buffer = []
- longest = 0
-
- pack_infos = []
- for shfl_i in inds:
- if num_tokens[shfl_i] + sum(length_buffer) <= self.target:
- item_buffer.append(shfl_i)
- length_buffer.append(num_tokens[shfl_i])
- longest = max(longest, num_tokens[shfl_i])
- else:
- if len(item_buffer) > 0:
- info = {
- "dataset_id": dataset_id,
- "indices": item_buffer,
- "longest": int(longest),
- }
- pack_infos.append(info)
-
- item_buffer = [shfl_i]
- length_buffer = [num_tokens[shfl_i]]
- longest = num_tokens[shfl_i]
-
- if len(item_buffer) > 0:
- info = {
- "dataset_id": dataset_id,
- "indices": item_buffer,
- "longest": int(longest),
- }
-
- pack_infos.append(info)
-
- pack_infos = Dataset.from_list(pack_infos)
-
- return pack_infos
-
- def __len__(self):
- return len(self.pack_infos)
-
- def __getitem__(self, item):
- indices = self.pack_infos[item]["indices"]
- dataset_id = self.pack_infos[item]["dataset_id"]
- return [self.datasets[dataset_id][i] for i in indices]
-
-
-class HardPackDataset(torch.utils.data.Dataset):
- def __init__(self, datasets, target=2048, blend=True, sort=False):
- if blend:
- num_tokens = [np.concatenate([dset.num_tokens for dset in datasets])]
- datasets = [ConcatDataset(datasets)]
- else:
- num_tokens = [dset.num_tokens for dset in datasets]
- self.datasets = datasets
- self.target = target
-
- pack_infos = []
- for i, dataset in enumerate(self.datasets):
- _info = self.get_pack_info(dataset, i, num_tokens[i])
- pack_infos.append(_info)
-
- _ranges_left = []
- _ranges_right = []
- _num_packed_samples = []
- _indices = []
- _max_length_per_pack = []
- _dataset_id = []
- for info in pack_infos:
- _ranges_left.extend(info["ranges_left"])
- _ranges_right.extend(info["ranges_right"])
- _num_packed_samples.append(info["num_packed_samples"])
- _indices.extend(info["indices"])
- _max_length_per_pack.extend(info["max_length_per_pack"])
- _dataset_id.extend(info["dataset_id"])
-
- self.pack_infos = {
- "ranges_left": _ranges_left,
- "ranges_right": _ranges_right,
- "num_packed_samples": _num_packed_samples,
- "indices": _indices,
- "max_length_per_pack": _max_length_per_pack,
- "dataset_id": _dataset_id,
- }
-
- @classmethod
- def _cal_max_length(cls, begin, end, shfl_item_rngs_left, shfl_item_rngs_right):
- left = bisect.bisect(shfl_item_rngs_right, begin)
- right = bisect.bisect(shfl_item_rngs_left, end)
- max_length = 0
- for i in range(left, right):
- item_begin = shfl_item_rngs_left[i]
- item_end = shfl_item_rngs_right[i]
- inner_l = max(begin, item_begin) - item_begin
- inner_r = min(end, item_end) - item_begin
- trunc_size = inner_r - inner_l
- max_length = max(max_length, trunc_size)
- return max_length
-
- def get_pack_info(self, dataset, dataset_id, num_tokens):
- # The number of data items after packing
- num_packed_samples = int(num_tokens.sum() / self.target)
-
- # Shuffle the order of the original dataset
- # The packing will proceed according to the order after shuffle.
- # Assume the following conditions hold:
- # (1) shfl_inds = [3, 1, 2, 0]
- # (2) self._ori_lens[3] + self._ori_lens[1] = max_length
- # (3) self._ori_lens[2] + self._ori_lens[0] = max_length
- # Ultimately, dataset[3] and dataset[1] will be combined into a new
- # data, and dataset[2] and dataset[0] will be combined into a new data.
- inds = [i for i in range(len(dataset))]
- # if seed is not None:
- # random.seed(seed)
- random.shuffle(inds)
- shfl_inds = inds
-
- # shuffled cumulative lengths
- shfl_lens = [num_tokens[i] for i in shfl_inds]
- shfl_acc_lens = list(itertools.accumulate(shfl_lens))
-
- shfl_item_rngs_left = [0] + shfl_acc_lens[:-1]
- shfl_item_rngs_right = shfl_acc_lens
-
- max_length_per_pack = []
- belong_dataset_ids = []
- for i in range(num_packed_samples):
- begin = i * self.target
- end = (i + 1) * self.target
- max_length_per_pack.append(
- self._cal_max_length(
- begin, end, shfl_item_rngs_left, shfl_item_rngs_right
- )
- )
- belong_dataset_ids.append(dataset_id)
-
- pack_infos = {
- "ranges_left": shfl_item_rngs_left,
- "ranges_right": shfl_item_rngs_right,
- "num_packed_samples": num_packed_samples,
- "indices": shfl_inds,
- "dataset_id": belong_dataset_ids,
- "max_length_per_pack": max_length_per_pack,
- }
-
- # pack_infos = Dataset.from_list(pack_infos)
-
- return pack_infos
-
- def _pack_ids_and_labels_in_range(self, begin: int, end: int):
- """Packs ids and labels in a given range using bisection method.
-
- Args:
- begin: Index indicating the beginning of the range.
- end: Index indicating the end of the range.
-
- Returns:
- A tuple containing packed ids, labels, and cumulative lengths.
- """
-
- # Use binary search to find dataset positions that fall within begin
- # and end range
- left = bisect.bisect(self.pack_infos["ranges_left"], begin)
- right = bisect.bisect(self.pack_infos["ranges_right"], end)
-
- trunc_input_ids = []
- trunc_labels = []
- trunc_sizes = []
-
- for i in range(left, right):
- # Determine the real range we will cut in current original item
- item_begin = self.pack_infos["ranges_left"][i]
- item_end = self.pack_infos["ranges_right"][i]
-
- # Calculate exact positions within current dataset item
- inner_l = max(begin, item_begin) - item_begin
- inner_r = min(end, item_end) - item_begin
-
- # Get original data and labels
- ori_idx = self.pack_infos["indices"][i]
- ori_dataset_id = self.pack_infos["dataset_id"][i]
- ori_input_ids = self.datasets[ori_dataset_id][ori_idx]["input_ids"]
- ori_labels = self.datasets[ori_dataset_id][ori_idx]["labels"]
-
- # Add original data and labels from calculated positions
- # to trunc_ids and trunc_labels
- trunc_input_ids.extend(ori_input_ids[inner_l:inner_r])
- trunc_labels.extend(ori_labels[inner_l:inner_r])
- trunc_sizes.append(inner_r - inner_l)
-
- # return populated lists of truncated ids, labels and their cumulative
- # lengths
- return trunc_input_ids, trunc_labels, trunc_sizes
-
- def __len__(self):
- return len(self.pack_infos["indices"])
-
- def __getitem__(self, item):
- """Returns a dict containing packed data in the given item.
-
- Args:
- item: An index to retrieve packed data.
-
- Returns:
- A dict including packed input_ids, labels, and cumulative_len.
- """
- # The cumulative length from the start position of this data
- begin = item * self.target
- # The cumulative length from the end position of this data
- end = (item + 1) * self.target
-
- # Extract data within the range from the shuffled original dataset.
- _res = self._pack_ids_and_labels_in_range(begin, end)
- packed_input_ids, packed_labels, num_tokens = _res
- assert self.target == len(packed_input_ids) == len(packed_labels)
-
- packed = {
- "input_ids": packed_input_ids,
- "labels": packed_labels,
- "num_tokens": num_tokens,
- }
-
- return packed
diff --git a/code/xtuner/_lite/datasets/streaming.py b/code/xtuner/_lite/datasets/streaming.py
deleted file mode 100644
index db1c6da7d62053a4a41e41376b84f34487c9d0e1..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/streaming.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-
-class Streaming:
- def __init__(self, file, max_epoch=1):
- self.file = file
- self.offset = 0
- self.epoch = 1
- self.max_epoch = max_epoch
-
- def __iter__(self):
- return self
-
- def __next__(self):
- with open(self.file) as f:
- f.seek(self.offset)
- line = f.readline()
-
- if not line and self.epoch < self.max_epoch:
- self.offset = 0
- self.epoch += 1
- return next(self)
-
- elif not line and self.epoch == self.max_epoch:
- raise StopIteration
-
- self.offset = f.tell()
- return line
diff --git a/code/xtuner/_lite/datasets/utils/__init__.py b/code/xtuner/_lite/datasets/utils/__init__.py
deleted file mode 100644
index cc432ab6c1217ee7851d978461d456d38cbfa792..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/utils/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .convert import OPENAI_CONVERT_MAP
-from .load import DATASET_CLS_MAP, load_datasets
-from .utils import apply_exif_orientation, move_data_to_device
-
-__all__ = [
- "OPENAI_CONVERT_MAP",
- "DATASET_CLS_MAP",
- "load_datasets",
- "apply_exif_orientation",
- "move_data_to_device",
-]
diff --git a/code/xtuner/_lite/datasets/utils/convert.py b/code/xtuner/_lite/datasets/utils/convert.py
deleted file mode 100644
index 9db4d3d2190ec36a594cde6cb3a09af02f3ea25f..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/utils/convert.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import re
-
-from xtuner._lite.chat import ChatMessages
-
-
-class XTunerFormat2Openai:
- @classmethod
- def source_format(cls):
- data = {
- "conversation": [
- {"system": "SYSTEM", "input": "INPUT", "output": "OUTPUT"},
- {"input": "INPUT", "output": "OUTPUT"},
- ]
- }
- return data
-
- @classmethod
- def target_format(cls):
- data = {
- "messages": [
- {"role": "system", "content": "SYSTEM"},
- {"role": "user", "content": "INPUT"},
- {"role": "assistant", "content": "OUTPUT"},
- {"role": "user", "content": "INPUT"},
- {"role": "assistant", "content": "OUTPUT"},
- ]
- }
- return data
-
- @staticmethod
- def convert(data):
- ROLE_MAPPING = {"system": "system", "input": "user", "output": "assistant"}
- messages = []
- for single_turn_conversation in data["conversation"]:
- for role, content in single_turn_conversation.items():
- messages.append({"role": ROLE_MAPPING[role], "content": content})
- return ChatMessages.from_dict({"messages": messages})
-
-
-class Alpaca2Openai:
- @classmethod
- def source_format(cls):
- data = {
- "instruction": "INSTRUCTION",
- "input": "INPUT",
- "output": "OUTPUT",
- }
- return data
-
- @classmethod
- def target_format(cls):
- data = {
- "messages": [
- {"role": "user", "content": "INSTRUCTION\nINPUT"},
- {"role": "assistant", "content": "OUTPUT"},
- ]
- }
- return data
-
- @staticmethod
- def convert(data):
- if data.get("output") == "":
- return ChatMessages.from_dict({"messages": []})
- else:
- return ChatMessages.from_dict(
- {
- "messages": [
- {
- "role": "user",
- "content": f"{data['instruction']}\n{data['input']}",
- },
- {"role": "assistant", "content": f"{data['output']}"},
- ]
- }
- )
-
-
-def llava_to_openai(data):
- image_token = ""
- conversations = data["conversations"]
- messages = []
-
- if "image" in data:
- image_urls = data["image"]
- if isinstance(image_urls, str):
- image_urls = [image_urls]
- else:
- image_urls = None
-
- while conversations and conversations[0]["from"] == "gpt":
- # Skip the first one if it is from gpt
- conversations = conversations[1:]
-
- image_id = 0
- for convs in conversations:
- if convs["from"] == "human":
- pattern = f"({image_token})"
- chunks = re.split(pattern, convs["value"])
-
- text_content = []
- img_content = []
-
- for chunk in chunks:
- if chunk == image_token:
- url = image_urls[image_id]
- if not isinstance(url, str):
- raise TypeError(data)
- # assert , image_url
- item = dict(type="image_url", image_url=url)
- img_content.append(item)
- image_id += 1
- elif len(chunk.strip()):
- item = dict(type="text", text=chunk.strip())
- text_content.append(item)
-
- msg = {"role": "user", "content": img_content + text_content}
- messages.append(msg)
-
- elif convs["from"] == "gpt":
- msg = {"role": "assistant", "content": convs["value"]}
- messages.append(msg)
- else:
- raise NotImplementedError
-
- return ChatMessages.from_dict({"messages": messages})
-
-
-def llava_to_openai_interleave(data):
- image_token = ""
- conversations = data["conversations"]
- messages = []
-
- if "image" in data:
- image_urls = data["image"]
- if isinstance(image_urls, str):
- image_urls = [image_urls]
- else:
- image_urls = None
-
- while conversations and conversations[0]["from"] == "gpt":
- # Skip the first one if it is from gpt
- conversations = conversations[1:]
-
- image_id = 0
- for convs in conversations:
- if convs["from"] == "human":
- pattern = f"({image_token})"
- chunks = re.split(pattern, convs["value"])
-
- content = []
-
- for chunk in chunks:
- if chunk == image_token:
- url = image_urls[image_id]
- if not isinstance(url, str):
- raise TypeError(data)
- # assert , image_url
- item = dict(type="image_url", image_url=url)
- content.append(item)
- image_id += 1
- elif len(chunk.strip()):
- item = dict(type="text", text=chunk.strip())
- content.append(item)
-
- msg = {"role": "user", "content": content}
- messages.append(msg)
-
- elif convs["from"] == "gpt":
- msg = {"role": "assistant", "content": convs["value"]}
- messages.append(msg)
- else:
- raise NotImplementedError
-
- return ChatMessages.from_dict({"messages": messages})
-
-
-def official_openai(data):
- if "messages" in data:
- return ChatMessages.from_dict(data)
- elif "message_data" in data:
- return ChatMessages.from_dict({"messages": data["message_data"]})
- elif "dialogs" in data:
- return ChatMessages.from_dict({"messages": data["dialogs"]})
- else:
- return ChatMessages.from_dict({"messages": data})
-
-
-OPENAI_CONVERT_MAP = {
- "llava": llava_to_openai,
- "llava_interleave": llava_to_openai_interleave,
- "alpaca": Alpaca2Openai.convert,
- "xtuner": XTunerFormat2Openai.convert,
- "openai": official_openai,
-}
diff --git a/code/xtuner/_lite/datasets/utils/load.py b/code/xtuner/_lite/datasets/utils/load.py
deleted file mode 100644
index b6f7befb6146189899538a325c38a28e97ae96fb..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/utils/load.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import json
-import math
-import os
-import random
-import re
-
-from torch import distributed as dist
-from tqdm import tqdm
-
-from xtuner._lite import get_logger
-
-from ..json import JsonDataset
-from ..jsonl import JsonlDataset
-
-logger = get_logger()
-
-DATASET_CLS_MAP = {".jsonl": JsonlDataset, ".json": JsonDataset}
-
-
-def load_hf_dataset(path, split="train", sample_ratio=1.0, cache_dir=None, map_fn=None):
- from datasets import load_dataset
-
- dataset = load_dataset(path)[split]
-
- if map_fn:
- dataset = dataset.map(map_fn, num_proc=8)
-
- if sample_ratio != 1:
- ori_samples = len(dataset)
- target_samples = int(sample_ratio * ori_samples)
- indices = random.choices([i for i in range(ori_samples)], k=target_samples)
- dataset = dataset.select(indices)
-
- dataset = dataset.to_list()
-
- # if init_fn:
- # dataset = init_fn(dataset)
-
- # if cache_dir and isinstance(dataset, CacheDataset):
- # dataset.cache(cache_dir)
-
- return dataset
-
-
-def load_from_cache(cache_dir, init_fn):
- if dist.is_available():
- world_size = dist.get_world_size()
- rank = dist.get_rank()
- else:
- world_size = 1
- rank = 0
-
- sub_cache_dirs = []
- for _path in tqdm(os.listdir(cache_dir)):
- path = os.path.join(cache_dir, _path)
- if os.path.isdir(path):
- sub_cache_dirs.append(path)
-
- num_dsets = len(sub_cache_dirs)
- avg_num = math.ceil(num_dsets / world_size)
- start = rank * avg_num
- end = min((rank + 1) * avg_num, num_dsets)
- desc = f"[Rank {rank}] Loading Cached Dataset"
-
- rank_datasets = []
- for ind in tqdm(range(start, end), desc=desc):
- dset = init_fn(sub_cache_dirs[ind])
- rank_datasets.append(dset)
-
- if dist.is_available() and world_size > 1:
- dist.barrier()
- buffers = [None] * world_size
- dist.all_gather_object(buffers, rank_datasets)
- world_datasets = []
- for dsets_per_rank in buffers:
- world_datasets.extend(dsets_per_rank)
-
- assert len(world_datasets) == num_dsets
- else:
- world_datasets = rank_datasets
-
- return world_datasets
-
-
-def load_local_datasets(
- paths,
- file_types,
- file_pattern=None,
- cache_dir=None,
- sample_ratios=1.0,
- map_fns=None,
- max_length=None,
-):
- if isinstance(paths, str):
- paths = [paths]
-
- if isinstance(sample_ratios, (tuple, list)):
- if len(sample_ratios) == 1:
- sample_ratios = list(sample_ratios) * len(paths)
-
- if len(sample_ratios) != len(paths):
- raise RuntimeError(
- f"There are {len(paths)} paths, but only "
- f"{len(sample_ratios)} sample ratios were set."
- )
-
- if map_fns is None:
- map_fns = [None] * len(paths)
-
- if isinstance(map_fns, (tuple, list)):
- if len(map_fns) == 1:
- map_fns = list(map_fns) * len(paths)
-
- if len(map_fns) != len(paths):
- raise RuntimeError(
- f"There are {len(paths)} paths, but only"
- f"{len(map_fns)} map fns were set."
- )
-
- files = []
- file_sample_ratios = []
- file_map_fns = []
-
- for pid, path in enumerate(paths):
- if os.path.isdir(path):
- dir_files = []
- for root, dirs, _files in os.walk(path, followlinks=True):
- dirs.sort()
- for relative_path in sorted(_files):
- suffix = os.path.splitext(relative_path)[-1]
- absolute_path = os.path.join(root, relative_path)
- if file_pattern is not None:
- if bool(re.match(file_pattern, absolute_path)):
- dir_files.append(absolute_path)
- elif suffix in file_types:
- dir_files.append(absolute_path)
-
- _num_dir_files = len(dir_files)
- if _num_dir_files == 0:
- raise RuntimeError(
- f"There are no files with the suffix {file_types}" f"in `{path}`."
- )
-
- logger.info(f"Found {len(dir_files)} files in {path}")
- files.extend(dir_files)
- file_sample_ratios.extend([sample_ratios[pid]] * _num_dir_files)
- file_map_fns.extend([map_fns[pid]] * _num_dir_files)
-
- elif os.path.isfile(path):
- files.append(path)
- file_sample_ratios.append(sample_ratios[pid])
- file_map_fns.append(map_fns[pid])
-
- else:
- raise RuntimeError(f"`{path}` not found.")
-
- num_files = len(files)
-
- datasets = []
- for i in range(num_files):
- _path = files[i]
- _ratio = file_sample_ratios[i]
- _map_fn = file_map_fns[i]
- _suffix = os.path.splitext(_path)[-1]
-
- dataset_cls = DATASET_CLS_MAP[_suffix]
- _dataset = dataset_cls(_path, _ratio, _map_fn, cache_dir, max_length)
- datasets.append(_dataset)
-
- return datasets
-
-
-def load_datasets(
- paths,
- sources="local",
- sample_ratios=1.0,
- file_types=DATASET_CLS_MAP.keys(),
- file_pattern=None,
- cache_dir=None,
- map_fns=None,
- max_length=None,
-):
- if isinstance(paths, str):
- paths = [paths]
-
- num_paths = len(paths)
-
- if isinstance(sample_ratios, (float, int)):
- sample_ratios = [sample_ratios] * num_paths
-
- if isinstance(sample_ratios, (tuple, list)):
- if len(sample_ratios) == 1:
- sample_ratios = list(sample_ratios) * num_paths
-
- if len(sample_ratios) != num_paths:
- raise RuntimeError(
- f"There are {num_paths} paths, but only "
- f"{len(sample_ratios)} sample ratios were set."
- )
-
- if isinstance(sources, str):
- sources = [sources]
-
- if isinstance(sources, (tuple, list)):
- if len(sources) == 1:
- sources = list(sources) * num_paths
-
- if len(sources) != num_paths:
- raise RuntimeError(
- f"There are {num_paths} paths, but only "
- f"{len(sources)} sources were set."
- )
-
- if not isinstance(map_fns, (tuple, list)):
- map_fns = [map_fns] * num_paths
-
- if isinstance(map_fns, (tuple, list)):
- if len(map_fns) == 1:
- map_fns = list(map_fns) * num_paths
-
- if len(map_fns) != num_paths:
- raise RuntimeError(
- f"There are {num_paths} paths, but only"
- f"{len(map_fns)} map fns were set."
- )
-
- local_inds = [i for i, src in enumerate(sources) if src == "local"]
- local_paths = [paths[ind] for ind in local_inds]
- local_map_fns = [map_fns[ind] for ind in local_inds]
- local_sample_ratios = [sample_ratios[ind] for ind in local_inds]
-
- hf_inds = [i for i, src in enumerate(sources) if src == "huggingface"]
- hf_paths = [paths[ind] for ind in hf_inds]
- hf_map_fns = [map_fns[ind] for ind in hf_inds]
- hf_sample_ratios = [sample_ratios[ind] for ind in hf_inds]
-
- datasets = []
- if len(local_inds):
- local_datasets = load_local_datasets(
- local_paths,
- file_types,
- file_pattern,
- cache_dir,
- local_sample_ratios,
- local_map_fns,
- max_length,
- )
- datasets.extend(local_datasets)
-
- if len(hf_inds):
- cached_infos = {}
- for i in range(len(hf_inds)):
- if cache_dir:
- digits = len(str(abs(len(hf_inds))))
- cache_id = f"cache-hf-{i+1:0{digits}}-of-" f"{len(hf_inds):0{digits}}"
- sub_cache_dir = os.path.join(cache_dir, cache_id)
- else:
- sub_cache_dir = None
- dset = load_hf_dataset(
- hf_paths[i],
- sample_ratio=hf_sample_ratios[i],
- map_fn=hf_map_fns[i],
- cache_dir=sub_cache_dir,
- max_length=max_length,
- )
- datasets.append(dset)
- breakpoint()
- if cache_dir:
- infos = {
- "path": hf_paths[i],
- "num_samples": dset.num_samples,
- "num_tokens": dset.total_tokens,
- }
- cached_infos[cache_id] = infos
-
- if cache_dir:
- _path = os.path.join(cache_dir, "hf_infos.json")
- with open(_path, "w") as f:
- json.dump(cached_infos, f)
-
- return datasets
-
-
-def load_ms_dataset():
- pass
diff --git a/code/xtuner/_lite/datasets/utils/utils.py b/code/xtuner/_lite/datasets/utils/utils.py
deleted file mode 100644
index 9b3422240490effd4ea00c9fbef8e5fb3304f28b..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/datasets/utils/utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from collections.abc import Mapping
-
-import torch
-from PIL import Image
-
-_EXIF_ORIENT = 274 # exif 'Orientation' tag
-
-
-def apply_exif_orientation(image):
- """Applies the exif orientation correctly.
-
- This code exists per the bug:
- https://github.com/python-pillow/Pillow/issues/3973
- with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
- various methods, especially `tobytes`
-
- Function based on:
- https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
- https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
-
- Args:
- image (PIL.Image): a PIL image
-
- Returns:
- (PIL.Image): the PIL image with exif orientation applied, if applicable
- """
- if not hasattr(image, "getexif"):
- return image
-
- try:
- exif = image.getexif()
- except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
- exif = None
-
- if exif is None:
- return image
-
- orientation = exif.get(_EXIF_ORIENT)
-
- method = {
- 2: Image.FLIP_LEFT_RIGHT,
- 3: Image.ROTATE_180,
- 4: Image.FLIP_TOP_BOTTOM,
- 5: Image.TRANSPOSE,
- 6: Image.ROTATE_270,
- 7: Image.TRANSVERSE,
- 8: Image.ROTATE_90,
- }.get(orientation)
-
- if method is not None:
- return image.transpose(method)
- return image
-
-
-def move_data_to_device(data, device="cuda"):
- """Prepares one `data` before feeding it to the model, be it a tensor or a
- nested list/dictionary of tensors."""
- if isinstance(data, Mapping):
- return type(data)({k: move_data_to_device(v) for k, v in data.items()})
- elif isinstance(data, (tuple, list)):
- return type(data)(move_data_to_device(v) for v in data)
- elif isinstance(data, torch.Tensor):
- kwargs = {"device": device}
- return data.to(non_blocking=True, **kwargs)
- return data
diff --git a/code/xtuner/_lite/device.py b/code/xtuner/_lite/device.py
deleted file mode 100644
index ff74ebe217346feded3617ca635b6a6dee464a1a..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/device.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-
-
-def get_device():
- device = None
- if torch.cuda.is_available():
- device = "cuda"
- else:
- try:
- import torch_npu # noqa: F401
-
- device = "npu"
- except ImportError:
- pass
- try:
- import torch_mlu # noqa: F401
-
- device = "mlu"
- except ImportError:
- pass
-
- if device is None:
- raise NotImplementedError(
- "Supports only CUDA or NPU. If your device is CUDA or NPU, "
- "please make sure that your environmental settings are "
- "configured correctly."
- )
-
- return device
-
-
-def get_torch_device_module():
- device = get_device()
- if device == "cuda":
- return torch.cuda
- elif device == "npu":
- return torch.npu
- elif device == "mlu":
- return torch.mlu
- else:
- raise NotImplementedError
diff --git a/code/xtuner/_lite/modelings/.DS_Store b/code/xtuner/_lite/modelings/.DS_Store
deleted file mode 100644
index 8812c366a52f59eacb5bb9ef2d5a910a867b264e..0000000000000000000000000000000000000000
Binary files a/code/xtuner/_lite/modelings/.DS_Store and /dev/null differ
diff --git a/code/xtuner/_lite/modelings/__init__.py b/code/xtuner/_lite/modelings/__init__.py
deleted file mode 100644
index 1025f40879e66484eb0ae7e66357a88a4ef71530..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from .internlm2 import InternLM2Config, InternLM2ForCausalLM
-from .internlm3 import InternLM3Config, InternLM3ForCausalLM, InternLM3Tokenizer
-from .llava.modeling_llava import LlavaForConditionalGeneration
-from .llava.configuration_llava import EnhancedLlavaConfig
-from .llava.processing_llava import LlavaProcessor
-
-def register_remote_code():
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
- AutoConfig.register('internlm2', InternLM2Config, exist_ok=True)
- AutoModelForCausalLM.register(
- InternLM2Config, InternLM2ForCausalLM, exist_ok=True)
-
- AutoConfig.register('internlm3', InternLM3Config, exist_ok=True)
- AutoModelForCausalLM.register(
- InternLM3Config, InternLM3ForCausalLM, exist_ok=True)
- AutoTokenizer.register(
- InternLM3Config, InternLM3Tokenizer, exist_ok=True)
diff --git a/code/xtuner/_lite/modelings/internlm2/__init__.py b/code/xtuner/_lite/modelings/internlm2/__init__.py
deleted file mode 100644
index e43d72d4a59e6e2fa70ec01b701b996151451e20..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm2/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .configuration_internlm2 import InternLM2Config
-from .modeling_internlm2 import InternLM2ForCausalLM
diff --git a/code/xtuner/_lite/modelings/internlm2/configuration_internlm2.py b/code/xtuner/_lite/modelings/internlm2/configuration_internlm2.py
deleted file mode 100644
index 8b810794707c07014b2143b5fced9fe0c621c8eb..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm2/configuration_internlm2.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" InternLM2 model configuration"""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
-
-
-# Modified from transformers.model.llama.configuration_llama.LlamaConfig
-class InternLM2Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
- an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
- configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
- vocab_size (`int`, *optional*, defaults to 32000):
- Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`InternLM2Model`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 11008):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer decoder.
- num_key_value_heads (`int`, *optional*):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
- `num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 2048):
- The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*):
- Padding token id.
- bos_token_id (`int`, *optional*, defaults to 1):
- Beginning of stream token id.
- eos_token_id (`int`, *optional*, defaults to 2):
- End of stream token id.
- pretraining_tp (`int`, *optional*, defaults to 1):
- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
- document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
- to understand more about it. This value is necessary to ensure exact reproducibility
- of the pretraining results. Please refer to [this
- issue](https://github.com/pytorch/pytorch/issues/76232).
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
- strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
- `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
- `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
- these scaling strategies behave:
- https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
- experimental feature, subject to breaking API changes in future versions.
- """
- _auto_class = 'AutoConfig'
- model_type = 'internlm2'
- keys_to_ignore_at_inference = ['past_key_values']
-
- def __init__( # pylint: disable=W0102
- self,
- vocab_size=103168,
- hidden_size=4096,
- intermediate_size=11008,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=None,
- hidden_act='silu',
- max_position_embeddings=2048,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- pretraining_tp=1,
- tie_word_embeddings=False,
- bias=True,
- rope_theta=10000,
- rope_scaling=None,
- attn_implementation=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.bias = bias
-
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
- self.num_key_value_heads = num_key_value_heads
-
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.pretraining_tp = pretraining_tp
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self._rope_scaling_validation()
- self.attn_implementation = attn_implementation
- if self.attn_implementation is None:
- self.attn_implementation = 'eager'
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
- def _rope_scaling_validation(self):
- """
- Validate the `rope_scaling` configuration.
- """
- if self.rope_scaling is None:
- return
-
- if not isinstance(self.rope_scaling,
- dict) or len(self.rope_scaling) != 2:
- raise ValueError(
- '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
- f'got {self.rope_scaling}')
- rope_scaling_type = self.rope_scaling.get('type', None)
- rope_scaling_factor = self.rope_scaling.get('factor', None)
- if rope_scaling_type is None or rope_scaling_type not in [
- 'linear', 'dynamic'
- ]:
- raise ValueError(
- f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
- )
- if (rope_scaling_factor is None
- or not isinstance(rope_scaling_factor,
- (float, int)) or rope_scaling_factor < 1.0):
- raise ValueError(
- f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
- f'of type {type(rope_scaling_factor)}')
diff --git a/code/xtuner/_lite/modelings/internlm2/modeling_internlm2.py b/code/xtuner/_lite/modelings/internlm2/modeling_internlm2.py
deleted file mode 100644
index 69ddc61969e34c8ebe387e0dc8dfd1d85db576cf..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm2/modeling_internlm2.py
+++ /dev/null
@@ -1,1899 +0,0 @@
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch InternLM2.5 model."""
-import math
-import queue
-import threading
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from einops import rearrange
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutputWithPast,
- TokenClassifierOutput)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
-from transformers.utils import (add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_greater_or_equal_2_10, logging,
- replace_return_docstrings)
-
-try:
- from transformers.generation.streamers import BaseStreamer
-except Exception:
- BaseStreamer = None
-
-from .configuration_internlm2 import InternLM2Config
-
-try:
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import (index_first_axis, pad_input,
- unpad_input)
-except:
- pass
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = 'InternLM2Config'
-
-
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-class InternLM2RMSNorm(nn.Module):
- """InternLM2RMSNorm is equivalent to T5LayerNorm."""
-
- def __init__(self, hidden_size, eps=1e-6):
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance +
- self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
-
-
-class InternLM2RotaryEmbedding(nn.Module):
- """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
-
- def __init__(self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0):
- super().__init__()
- self.scaling_factor = scaling_factor
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (
- self.base
- **(torch.arange(0, self.dim, 2,
- dtype=torch.int64).float().to(device) / self.dim))
- self.register_buffer('inv_freq', inv_freq, persistent=False)
- # For BC we register cos and sin cached
- self.max_seq_len_cached = max_position_embeddings
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- # x: [bs, num_attention_heads, seq_len, head_size]
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
- position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 since bfloat16 loses precision on long contexts
- # See https://github.com/huggingface/transformers/pull/29285
- device_type = x.device.type
- device_type = device_type if isinstance(
- device_type, str) and device_type != 'mps' else 'cpu'
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float()
- @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
- """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
- def forward(self, x, position_ids):
- # difference to the original RoPE: a scaling factor is aplied to the position ids
- position_ids = position_ids.float() / self.scaling_factor
- cos, sin = super().forward(x, position_ids)
- return cos, sin
-
-
-class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
- """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
- Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
- def forward(self, x, position_ids):
- # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_position_embeddings:
- base = self.base * ((self.scaling_factor * seq_len /
- self.max_position_embeddings) -
- (self.scaling_factor - 1))**(
- self.dim / (self.dim - 2))
- inv_freq = 1.0 / (
- base
- **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(
- x.device) / self.dim))
- self.register_buffer(
- 'inv_freq', inv_freq,
- persistent=False) # TODO joao: this may break with compilation
-
- cos, sin = super().forward(x, position_ids)
- return cos, sin
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
- """Applies Rotary Position Embedding to the query and key tensors.
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-class InternLM2MLP(nn.Module):
- """MLP for InternLM2 model."""
-
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.w1 = nn.Linear(
- self.hidden_size, self.intermediate_size, bias=False)
- self.w3 = nn.Linear(
- self.hidden_size, self.intermediate_size, bias=False)
- self.w2 = nn.Linear(
- self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
-
- return down_proj
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-class InternLM2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self,
- config: InternLM2Config,
- layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
- 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
- 'when creating this class.')
-
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
- f' and `num_heads`: {self.num_heads}).')
-
- self.wqkv = nn.Linear(
- self.hidden_size,
- (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
- bias=config.bias,
- )
- self.wo = nn.Linear(
- self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
-
- self._init_rope()
-
- def _init_rope(self):
- if self.config.rope_scaling is None:
- self.rotary_emb = InternLM2RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.rope_theta,
- )
- else:
- scaling_type = self.config.rope_scaling['type']
- scaling_factor = self.config.rope_scaling['factor']
- if scaling_type == 'linear':
- self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- )
- elif scaling_type == 'dynamic':
- self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- )
- else:
- raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False, # pylint: disable=unused-argument
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- # split qkv_states by tp size
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
- qkv_states = torch.cat(
- [
- F.linear(hidden_states, qkv_slice)
- for qkv_slice in qkv_slices
- ],
- dim=-1 # pylint: disable=E1102
- )
- else:
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states,
- 'b q h gs d -> b q (h gs) d').transpose(1, 2)
- key_states = qkv_states[..., -2, :].transpose(1, 2)
- value_states = qkv_states[..., -1, :].transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_weights = torch.matmul(query_states, key_states.transpose(
- 2, 3)) / math.sqrt(self.head_dim)
-
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(
- attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
- f' {attn_output.size()}')
-
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(
- self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.wo.weight.split(
- self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class InternLM2FlashAttention2(InternLM2Attention):
- """
- InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement,
- # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
- # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
- # produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if isinstance(past_key_value, StaticCache):
- raise ValueError(
- '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` '
- 'make sure to use `sdpa` in the mean time, and open an issue at '
- 'https://github.com/huggingface/transformers')
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout
- # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # dropout_rate = self.attention_dropout if self.training else 0.0
- dropout_rate = 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (InternLM2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.wqkv.weight.dtype
-
- logger.warning_once(
- f'The input hidden states seems to be silently casted in float32, this might be related to'
- f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
- f' {target_dtype}.')
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=dropout_rate)
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.hidden_size).contiguous()
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value # pylint: disable=E0606
-
- def _flash_attention_forward(self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length,
- dropout=0.0,
- softmax_scale=None):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`float`):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- """
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
- # For details, please see the comment in InternLM2FlashAttention2 __init__.
- causal = self.is_causal and query_length != 1
-
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, attention_mask,
- query_length)
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
- query_length) # pylint: disable=E0606
- else:
- attn_output = flash_attn_func( # pylint: disable=E0606
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal)
-
- return attn_output
-
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
- query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
- attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
- key_layer = index_first_axis( # pylint: disable=E0606
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim), indices_k)
- value_layer = index_first_axis( # pylint: disable=E0606
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim), indices_k)
- if query_length == kv_seq_len:
- query_layer = index_first_axis( # pylint: disable=E0606
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
- head_dim), indices_k)
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
- query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LllamaSdpaAttention with Llama->InternLM2
-class InternLM2SdpaAttention(InternLM2Attention):
- """
- InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
- `InternLM2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
- to adapt to SDPA API.
- """
-
- # Adapted from InternLM2Attention.forward
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
- # once this is implemented.
- logger.warning_once(
- 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` '
- 'does not support `output_attentions=True`. Falling back to the manual attention implementation, '
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. '
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- causal_mask = attention_mask
- if attention_mask is not None:
- causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
- # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == 'cuda' and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
- # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
- # options. An inline conditional prevents dynamic shapes from compiling.
- is_causal = bool(causal_mask is None and q_len > 1)
-
- attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=0.0,
- is_causal=is_causal,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-
- attn_output = self.wo(attn_output)
-
- return attn_output, None, past_key_value
-
-
-INTERNLM2_ATTENTION_CLASSES = {
- 'eager': InternLM2Attention,
- 'flash_attention_2': InternLM2FlashAttention2,
- 'sdpa': InternLM2SdpaAttention,
-}
-
-
-# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
-class InternLM2DecoderLayer(nn.Module):
- """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
-
- def __init__(self, config: InternLM2Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.layer_idx = layer_idx
-
- self.attention = INTERNLM2_ATTENTION_CLASSES[
- config.attn_implementation](
- config=config, layer_idx=layer_idx)
-
- self.feed_forward = InternLM2MLP(config)
- self.attention_norm = InternLM2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
- self.ffn_norm = InternLM2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
- torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- residual = hidden_states
-
- hidden_states = self.attention_norm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.attention(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.ffn_norm(hidden_states)
- hidden_states = self.feed_forward(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states, )
-
- if output_attentions:
- outputs += (self_attn_weights, )
-
- if use_cache:
- outputs += (present_key_value, )
-
- return outputs
-
-
-InternLM2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
- Parameters:
- config ([`InternLM2Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
-@add_start_docstrings(
- 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
- InternLM2_START_DOCSTRING,
-)
-class InternLM2PreTrainedModel(PreTrainedModel):
- """
- InternLM2 pretraiend model's base class.
- """
-
- config_class = InternLM2Config
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
- _no_split_modules = ['InternLM2DecoderLayer']
- _skip_keys_device_placement = ['past_key_values']
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-InternLM2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance;
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
-
-
-# Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
-@add_start_docstrings(
- 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
- InternLM2_START_DOCSTRING,
-)
-class InternLM2Model(InternLM2PreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
- Args:
- config: InternLM2Config
- """
-
- _auto_class = 'AutoModel'
-
- def __init__(self, config: InternLM2Config):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.config = config
-
- self.tok_embeddings = nn.Embedding(config.vocab_size,
- config.hidden_size,
- self.padding_idx)
-
- self.layers = nn.ModuleList([
- InternLM2DecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
- ])
- self.norm = InternLM2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError(
- 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
- )
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.tok_embeddings(input_ids)
-
- return_legacy_cache = False
- if use_cache and not isinstance(
- past_key_values,
- Cache): # kept for BC (non `Cache` `past_key_values` inputs)
- return_legacy_cache = True
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length(
- ) if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens,
- past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device)
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
- cache_position, past_key_values,
- output_attentions)
-
- # embed positions
- hidden_states = inputs_embeds
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[
- 2 if output_attentions else 1]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1], )
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- next_cache = next_decoder_cache if use_cache else None
- if return_legacy_cache:
- next_cache = next_cache.to_legacy_cache()
-
- if not return_dict:
- return tuple(
- v for v in
- [hidden_states, next_cache, all_hidden_states, all_self_attns]
- if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool,
- ):
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
- # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
- # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
- # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
- # See more context in https://github.com/huggingface/transformers/pull/29114
-
- if self.config.attn_implementation == 'flash_attention_2':
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length(
- ) if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
-
- dtype, device = input_tensor.dtype, input_tensor.device
- min_dtype = torch.finfo(dtype).min
- sequence_length = input_tensor.shape[1]
- if using_static_cache:
- target_length = past_key_values.get_max_length()
- else:
- target_length = (
- attention_mask.shape[-1] if isinstance(
- attention_mask, torch.Tensor) else past_seen_tokens +
- sequence_length + 1)
-
- if attention_mask is not None and attention_mask.dim() == 4:
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
- if attention_mask.max() != 0:
- raise ValueError(
- 'Custom 4D attention mask should be passed in inverted form with max==0`'
- )
- causal_mask = attention_mask
- else:
- causal_mask = torch.full((sequence_length, target_length),
- fill_value=min_dtype,
- dtype=dtype,
- device=device)
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(
- target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(
- input_tensor.shape[0], 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone(
- ) # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :
- mask_length] + attention_mask[:,
- None,
- None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :
- mask_length] = causal_mask[:, :, :, :
- mask_length].masked_fill(
- padding_mask,
- min_dtype)
- if (self.config.attn_implementation == 'sdpa'
- and attention_mask is not None
- and attention_mask.device.type == 'cuda'
- and not output_attentions):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- causal_mask = AttentionMaskConverter._unmask_unattended(
- causal_mask, min_dtype) # pylint: disable=E1120
-
- return causal_mask
-
-
-# Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
-class InternLM2ForCausalLM(InternLM2PreTrainedModel):
- """Causal language model (CLM) for InternLM2."""
-
- _auto_class = 'AutoModelForCausalLM'
- _tied_weights_keys = ['output.weight']
-
- def __init__(self, config):
- super().__init__(config)
- self.model = InternLM2Model(config)
- self.vocab_size = config.vocab_size
- self.output = nn.Linear(
- config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- def get_output_embeddings(self):
- return self.output
-
- def set_output_embeddings(self, new_embeddings):
- self.output = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- @replace_return_docstrings(
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- Returns:
- Example:
- ```python
- >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
- >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- )
-
- hidden_states = outputs[0]
- if self.config.pretraining_tp > 1:
- output_slices = self.output.weight.split(
- self.vocab_size // self.config.pretraining_tp, dim=0)
- logits = [
- F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable
- for i in range(self.config.pretraining_tp)
- ]
- logits = torch.cat(logits, dim=-1)
- else:
- logits = self.output(hidden_states)
- logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits, ) + outputs[1:]
- return (loss, ) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- cache_position=None,
- use_cache=True,
- **kwargs,
- ):
- past_length = 0
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- past_length = cache_position[
- 0] if cache_position is not None else past_key_values.get_seq_length(
- )
- max_cache_length = (
- torch.tensor(
- past_key_values.get_max_length(),
- device=input_ids.device)
- if past_key_values.get_max_length() is not None else None)
- cache_length = past_length if max_cache_length is None else torch.min(
- max_cache_length, past_length)
- # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
- max_cache_length = None
-
- # Keep only the unprocessed tokens:
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
- if attention_mask is not None and attention_mask.shape[
- 1] > input_ids.shape[1]:
- input_ids = input_ids[:, -(attention_mask.shape[1] -
- past_length):]
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
- # input_ids based on the past_length.
- elif past_length < input_ids.shape[1]:
- input_ids = input_ids[:, past_length:]
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
- if (max_cache_length is not None and attention_mask is not None
- and cache_length + input_ids.shape[1] > max_cache_length):
- attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
-
- position_ids = kwargs.get('position_ids', None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1]:]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {'inputs_embeds': inputs_embeds}
- else:
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
- # recompiles graphs as the stride of the inputs is a guard.
- # Ref: https://github.com/huggingface/transformers/pull/29114
- # TODO: use `next_tokens` directly instead.
- model_inputs = {'input_ids': input_ids.contiguous()}
-
- input_length = position_ids.shape[
- -1] if position_ids is not None else input_ids.shape[-1]
- if cache_position is None:
- cache_position = torch.arange(
- past_length,
- past_length + input_length,
- device=input_ids.device)
- elif use_cache:
- cache_position = cache_position[-input_length:]
-
- model_inputs.update({
- 'position_ids': position_ids,
- 'cache_position': cache_position,
- 'past_key_values': past_key_values,
- 'use_cache': use_cache,
- 'attention_mask': attention_mask,
- })
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (tuple(
- past_state.index_select(0, beam_idx.to(past_state.device))
- for past_state in layer_past), )
- return reordered_past
-
- def build_inputs(self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = None,
- meta_instruction=''):
- if history is None:
- history = []
- if tokenizer.add_bos_token:
- prompt = ''
- else:
- prompt = tokenizer.bos_token
- if meta_instruction:
- prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
- for record in history:
- prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
- prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
- return tokenizer([prompt], return_tensors='pt')
-
- @torch.no_grad()
- def chat(
- self,
- tokenizer,
- query: str,
- history: Optional[List[Tuple[str, str]]] = None,
- streamer: Optional[BaseStreamer] = None,
- max_new_tokens: int = 1024,
- do_sample: bool = True,
- temperature: float = 0.8,
- top_p: float = 0.8,
- meta_instruction:
- str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
- '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory '
- '(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
- '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such '
- 'as English and 中文.',
- **kwargs,
- ):
- if history is None:
- history = []
- inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
- inputs = {
- k: v.to(self.device)
- for k, v in inputs.items() if torch.is_tensor(v)
- }
- # also add end-of-assistant token in eos token id to avoid unnecessary generation
- eos_token_id = [
- tokenizer.eos_token_id,
- tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]
- ]
- outputs = self.generate(
- **inputs,
- streamer=streamer,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
- top_p=top_p,
- eos_token_id=eos_token_id,
- **kwargs,
- )
- outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
- response = tokenizer.decode(outputs, skip_special_tokens=True)
- response = response.split('<|im_end|>')[0]
- history = history + [(query, response)]
- return response, history
-
- @torch.no_grad()
- def stream_chat(
- self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = None,
- max_new_tokens: int = 1024,
- do_sample: bool = True,
- temperature: float = 0.8,
- top_p: float = 0.8,
- **kwargs,
- ):
- if history is None:
- history = []
- """
- Return a generator in format: (response, history)
- Eg.
- ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
- ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
- """
- if BaseStreamer is None:
- raise ModuleNotFoundError(
- 'The version of `transformers` is too low. Please make sure '
- 'that you have installed `transformers>=4.28.0`.')
-
- response_queue = queue.Queue(maxsize=20)
-
- class ChatStreamer(BaseStreamer):
- """
- Streamer used in generate to print words one by one.
- """
-
- def __init__(self, tokenizer) -> None:
- super().__init__()
- self.tokenizer = tokenizer
- self.queue = response_queue
- self.query = query
- self.history = history
- self.response = ''
- self.cache = []
- self.received_inputs = False
- self.queue.put(
- (self.response, history + [(self.query, self.response)]))
-
- def put(self, value):
- if len(value.shape) > 1 and value.shape[0] > 1:
- raise ValueError('ChatStreamer only supports batch size 1')
- elif len(value.shape) > 1:
- value = value[0]
-
- if not self.received_inputs:
- # The first received value is input_ids, ignore here
- self.received_inputs = True
- return
-
- self.cache.extend(value.tolist())
- token = self.tokenizer.decode(
- self.cache, skip_special_tokens=True)
- if token.strip() != '<|im_end|>':
- self.response = self.response + token
- history = self.history + [(self.query, self.response)]
- self.queue.put((self.response, history))
- self.cache = []
- else:
- self.end()
-
- def end(self):
- self.queue.put(None)
-
- def stream_producer():
- return self.chat(
- tokenizer=tokenizer,
- query=query,
- streamer=ChatStreamer(tokenizer=tokenizer),
- history=history,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
- top_p=top_p,
- **kwargs,
- )
-
- def consumer():
- producer = threading.Thread(target=stream_producer)
- producer.start()
- while True:
- res = response_queue.get()
- if res is None:
- return
- yield res
-
- return consumer()
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
-@add_start_docstrings(
- """
- The InternLM2 Model transformer with a sequence classification head on top (linear layer).
- [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
- (e.g. GPT-2) do.
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
- """Sequence Classification Head for InternLM2 Model."""
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = InternLM2Model(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError(
- 'Cannot handle batch sizes > 1 if no padding token is defined.'
- )
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
- sequence_lengths = torch.eq(
- input_ids, self.config.pad_token_id).int().argmax(-1) - 1
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
- sequence_lengths = sequence_lengths.to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device),
- sequence_lengths]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = 'regression'
- elif self.num_labels > 1 and (labels.dtype
- in (torch.long, torch.int)):
- self.config.problem_type = 'single_label_classification'
- else:
- self.config.problem_type = 'multi_label_classification'
-
- if self.config.problem_type == 'regression':
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == 'single_label_classification':
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(
- pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == 'multi_label_classification':
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits, ) + transformer_outputs[1:]
- return ((loss, ) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
-@add_start_docstrings(
- """
-The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
-SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
- """Question Answering model for InternLM2."""
-
- base_model_prefix = 'transformer'
-
- def __init__(self, config):
- super().__init__(config)
- self.transformer = InternLM2Model(config)
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.transformer.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.transformer.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- start_positions: Optional[torch.LongTensor] = None,
- end_positions: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
- r"""
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.transformer(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- sequence_output = outputs[0]
-
- logits = self.qa_outputs(sequence_output)
- start_logits, end_logits = logits.split(1, dim=-1)
- start_logits = start_logits.squeeze(-1).contiguous()
- end_logits = end_logits.squeeze(-1).contiguous()
-
- total_loss = None
- if start_positions is not None and end_positions is not None:
- # If we are on multi-GPU, split add a dimension
- if len(start_positions.size()) > 1:
- start_positions = start_positions.squeeze(-1).to(
- start_logits.device)
- if len(end_positions.size()) > 1:
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
- ignored_index = start_logits.size(1)
- start_positions = start_positions.clamp(0, ignored_index)
- end_positions = end_positions.clamp(0, ignored_index)
-
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
- start_loss = loss_fct(start_logits, start_positions)
- end_loss = loss_fct(end_logits, end_positions)
- total_loss = (start_loss + end_loss) / 2
-
- if not return_dict:
- output = (start_logits, end_logits) + outputs[2:]
- return ((total_loss, ) +
- output) if total_loss is not None else output
-
- return QuestionAnsweringModelOutput(
- loss=total_loss,
- start_logits=start_logits,
- end_logits=end_logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
-@add_start_docstrings(
- """
- The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
- output) e.g. for Named-Entity-Recognition (NER) tasks.
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
- """Token classification model for InternLM2."""
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = InternLM2Model(config)
- if getattr(config, 'classifier_dropout', None) is not None:
- classifier_dropout = config.classifier_dropout
- elif getattr(config, 'hidden_dropout', None) is not None:
- classifier_dropout = config.hidden_dropout
- else:
- classifier_dropout = 0.1
- self.dropout = nn.Dropout(classifier_dropout)
- self.score = nn.Linear(config.hidden_size, config.num_labels)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = outputs[0]
- sequence_output = self.dropout(sequence_output)
- logits = self.score(sequence_output)
-
- loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
- if not return_dict:
- output = (logits, ) + outputs[2:]
- return ((loss, ) + output) if loss is not None else output
-
- return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
diff --git a/code/xtuner/_lite/modelings/internlm3/__init__.py b/code/xtuner/_lite/modelings/internlm3/__init__.py
deleted file mode 100644
index a228b2903df97f8401f924ff55f4fc9130b274aa..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm3/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .configuration_internlm3 import InternLM3Config
-from .modeling_internlm3 import InternLM3ForCausalLM
-from .tokenization_internlm3 import InternLM3Tokenizer
diff --git a/code/xtuner/_lite/modelings/internlm3/configuration_internlm3.py b/code/xtuner/_lite/modelings/internlm3/configuration_internlm3.py
deleted file mode 100644
index d9f03eeb9d91670836665af2c83e151e25903c87..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm3/configuration_internlm3.py
+++ /dev/null
@@ -1,197 +0,0 @@
-# coding=utf-8
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" InternLM3 model configuration"""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-from transformers.utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class InternLM3Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
- an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
- configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 151936):
- Vocabulary size of the InternLM3 model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`InternLM3Model`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 22016):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*, defaults to 32):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 32768):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether the model's input and output word embeddings should be tied.
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
- accordingly.
- Expected contents:
- `rope_type` (`str`):
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
- 'llama3'], with 'default' being the original RoPE implementation.
- `factor` (`float`, *optional*):
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
- original maximum pre-trained length.
- `original_max_position_embeddings` (`int`, *optional*):
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
- pretraining.
- `attention_factor` (`float`, *optional*):
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
- computation. If unspecified, it defaults to value recommended by the implementation, using the
- `factor` field to infer the suggested value.
- `beta_fast` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
- ramp function. If unspecified, it defaults to 32.
- `beta_slow` (`float`, *optional*):
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
- ramp function. If unspecified, it defaults to 1.
- `short_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `long_factor` (`List[float]`, *optional*):
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
- size divided by the number of attention heads divided by 2
- `low_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
- `high_freq_factor` (`float`, *optional*):
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
- qkv_bias (`bool`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key and value projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- bias (`bool`, *optional*, defaults to `False`):
- Whether to use a bias in o_proj, up_proj, down_proj and gate_proj layers.
- head_dim (`int`, *optional*):
- The attention head dimension. If None, it will default to hidden_size // num_heads
-
- ```python
- >>> from transformers import InternLM3Model, InternLM3Config
-
- >>> # Initializing a InternLM3 style configuration
- >>> configuration = InternLM3Config()
-
- >>> # Initializing a model from the InternLM3-8B style configuration
- >>> model = InternLM3Model(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "internlm3"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- # Default tensor parallel plan for base model `InternLM3`
- base_model_tp_plan = {
- "layers.*.self_attn.q_proj": "colwise",
- "layers.*.self_attn.k_proj": "colwise",
- "layers.*.self_attn.v_proj": "colwise",
- "layers.*.self_attn.o_proj": "rowwise",
- "layers.*.mlp.gate_proj": "colwise",
- "layers.*.mlp.up_proj": "colwise",
- "layers.*.mlp.down_proj": "rowwise",
- }
-
- def __init__(
- self,
- vocab_size=128512,
- hidden_size=4096,
- intermediate_size=11008,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=32,
- hidden_act="silu",
- max_position_embeddings=32768,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- tie_word_embeddings=False,
- rope_theta=10000.0,
- rope_scaling=None,
- qkv_bias=False,
- attention_dropout=0.0,
- bias=False,
- head_dim=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self.qkv_bias = qkv_bias
- self.attention_dropout = attention_dropout
- self.bias = bias
- self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
- # Validate the correctness of rotary position embeddings parameters
- # BC: if there is a 'type' field, move it to 'rope_type'.
- if self.rope_scaling is not None and "type" in self.rope_scaling:
- self.rope_scaling["rope_type"] = self.rope_scaling["type"]
- rope_config_validation(self)
-
- super().__init__(
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
diff --git a/code/xtuner/_lite/modelings/internlm3/modeling_internlm3.py b/code/xtuner/_lite/modelings/internlm3/modeling_internlm3.py
deleted file mode 100644
index a651c4830e3bf606a5f6987d8367b2c975dd7b35..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm3/modeling_internlm3.py
+++ /dev/null
@@ -1,825 +0,0 @@
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/internlm3/modular_internlm3.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_internlm3.py file directly. One of our CI enforces this.
-# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from transformers.utils import logging
-
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.generation import GenerationMixin
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from .configuration_internlm3 import InternLM3Config
-
-
-logger = logging.get_logger(__name__)
-_CONFIG_FOR_DOC = "InternLM3Config"
-
-
-class InternLM3MLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor],
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
-):
- key_states = repeat_kv(key, module.num_key_value_groups)
- value_states = repeat_kv(value, module.num_key_value_groups)
-
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
- if attention_mask is not None:
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value_states)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- return attn_output, attn_weights
-
-
-class InternLM3Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: InternLM3Config, layer_idx: int):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = self.head_dim**-0.5
- self.attention_dropout = config.attention_dropout
- self.is_causal = True
- self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.qkv_bias)
- self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
- self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.qkv_bias)
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
-
-class InternLM3RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- InternLM3RMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
-class InternLM3DecoderLayer(nn.Module):
- def __init__(self, config: InternLM3Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.self_attn = InternLM3Attention(config=config, layer_idx=layer_idx)
- self.mlp = InternLM3MLP(config)
- self.input_layernorm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
-
-class InternLM3RotaryEmbedding(nn.Module):
- def __init__(self, config: InternLM3Config, device=None):
- super().__init__()
- # BC: "rope_type" was originally "type"
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
-
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
-
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- # This .to() is needed if the model has been moved to a device after being initialized (because
- # the buffer is automatically moved, but not the original copy)
- self.original_inv_freq = self.original_inv_freq.to(device)
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
-
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-INTERNLM3_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`InternLM3Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare InternLM3 Model outputting raw hidden-states without any specific head on top.",
- INTERNLM3_START_DOCSTRING,
-)
-class InternLM3PreTrainedModel(PreTrainedModel):
- config_class = InternLM3Config
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["InternLM3DecoderLayer"]
- _skip_keys_device_placement = ["past_key_values"]
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_flex_attn = True
- _supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-INTERNLM3_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance, see our
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
-
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
-
-
-@add_start_docstrings(
- "The bare InternLM3 Model outputting raw hidden-states without any specific head on top.",
- INTERNLM3_START_DOCSTRING,
-)
-class InternLM3Model(InternLM3PreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM3DecoderLayer`]
-
- Args:
- config: InternLM3Config
- """
-
- def __init__(self, config: InternLM3Config):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList(
- [InternLM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.norm = InternLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = InternLM3RotaryEmbedding(config=config)
- self.gradient_checkpointing = False
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(INTERNLM3_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if use_cache and past_key_values is None:
- past_key_values = DynamicCache()
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
-
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
-
- hidden_states = inputs_embeds
-
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **flash_attn_kwargs,
- )
-
- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- output = BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=past_key_values if use_cache else None,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
- return output if return_dict else output.to_tuple()
-
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
-
- dtype, device = input_tensor.dtype, input_tensor.device
- sequence_length = input_tensor.shape[1]
- if using_static_cache:
- target_length = past_key_values.get_max_cache_shape()
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
-
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type == "cuda"
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
- return causal_mask
-
- @staticmethod
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to plcae the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
-
- return causal_mask
-
-
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
-class InternLM3ForCausalLM(InternLM3PreTrainedModel, GenerationMixin):
- _auto_class = 'AutoModelForCausalLM'
- _tied_weights_keys = ["lm_head.weight"]
- _tp_plan = {"lm_head": "colwise_rep"}
-
- def __init__(self, config):
- super().__init__(config)
- self.model = InternLM3Model(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(INTERNLM3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- **kwargs: Unpack[KwargsForCausalLM],
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, InternLM3ForCausalLM
-
- >>> model = InternLM3ForCausalLM.from_pretrained("meta-internlm3/InternLM3-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-internlm3/InternLM3-2-7b-hf")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- **kwargs,
- )
-
- hidden_states = outputs[0]
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-
- loss = None
- if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
\ No newline at end of file
diff --git a/code/xtuner/_lite/modelings/internlm3/tokenization_internlm3.py b/code/xtuner/_lite/modelings/internlm3/tokenization_internlm3.py
deleted file mode 100644
index 462a82b1ec413a106da1dd27bc44dc806e3b4395..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internlm3/tokenization_internlm3.py
+++ /dev/null
@@ -1,295 +0,0 @@
-import os
-from shutil import copyfile
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-
-import sentencepiece as spm
-from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
-from transformers.utils import logging
-
-if TYPE_CHECKING:
- from transformers.tokenization_utils_base import TextInput
-
-logger = logging.get_logger(__name__)
-
-VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
-
-SPIECE_UNDERLINE = "▁"
-
-
-class InternLM3Tokenizer(PreTrainedTokenizer):
- """
- Construct a InternLM3 tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
- no padding token in the original model.
-
- Args:
- vocab_file (`str`):
- Path to the vocabulary file.
- unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
- token instead.
- bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
- The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
- eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`):
- The end of sequence token.
- pad_token (`str` or `tokenizers.AddedToken`, *optional*):
- A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
- attention mechanisms or loss computation.
- sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
- Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
- SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
- to set:
-
- - `enable_sampling`: Enable subword regularization.
- - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
-
- - `nbest_size = {0,1}`: No sampling is performed.
- - `nbest_size > 1`: samples from the nbest_size results.
- - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
- using forward-filtering-and-backward-sampling algorithm.
-
- - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
- BPE-dropout.
-
- add_bos_token (`bool`, *optional*, defaults to `True`):
- Whether or not to add an `bos_token` at the start of sequences.
- add_eos_token (`bool`, *optional*, defaults to `False`):
- Whether or not to add an `eos_token` at the end of sequences.
- clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
- Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
- extra spaces.
- use_default_system_prompt (`bool`, *optional*, defaults to `False`):
- Whether or not the default system prompt for InternLM3 should be used.
- spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
- Whether or not to add spaces between special tokens.
- spaces_for_interleaved_special_tokens (`bool`, *optional*, defaults to `False`):
- Whether or not to add spaces between special tokens that are interleaved with normal tokens.
- add_prefix_space (`bool`, *optional*, defaults to `True`):
- Whether or not to add an initial space to the input. This allows to treat the leading word just as any
- other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
- """
-
- vocab_files_names = VOCAB_FILES_NAMES
- model_input_names = ["input_ids", "attention_mask"]
- _auto_class = "AutoTokenizer"
-
- def __init__(
- self,
- vocab_file,
- unk_token="",
- bos_token="",
- eos_token="",
- pad_token=None,
- sp_model_kwargs: Optional[Dict[str, Any]] = None,
- add_bos_token=True,
- add_eos_token=False,
- clean_up_tokenization_spaces=False,
- use_default_system_prompt=False,
- spaces_between_special_tokens=False,
- spaces_for_interleaved_special_tokens=False,
- add_prefix_space=True,
- **kwargs,
- ):
- self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
- bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
- eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
- unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
- pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
-
- self.vocab_file = vocab_file
- self.add_bos_token = add_bos_token
- self.add_eos_token = add_eos_token
- self.use_default_system_prompt = use_default_system_prompt
- self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
- self.sp_model.Load(vocab_file)
- self.add_prefix_space = add_prefix_space
- self.spaces_for_interleaved_special_tokens = spaces_for_interleaved_special_tokens
-
- vocab_size = self.sp_model.get_piece_size()
- self.decoder = {i: self.sp_model.id_to_piece(i) for i in range(vocab_size)}
-
- super().__init__(
- bos_token=bos_token,
- eos_token=eos_token,
- unk_token=unk_token,
- pad_token=pad_token,
- add_bos_token=add_bos_token,
- add_eos_token=add_eos_token,
- sp_model_kwargs=sp_model_kwargs,
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
- use_default_system_prompt=use_default_system_prompt,
- spaces_between_special_tokens=spaces_between_special_tokens,
- add_prefix_space=add_prefix_space,
- **kwargs,
- )
-
- def __getstate__(self):
- state = self.__dict__.copy()
- state["sp_model"] = None
- state["sp_model_proto"] = self.sp_model.serialized_model_proto()
- return state
-
- def __setstate__(self, d):
- self.__dict__.update(d)
- self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
- self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
-
- @property
- def vocab_size(self):
- """Returns vocab size"""
- return self.sp_model.get_piece_size()
-
- def get_vocab(self):
- """Returns vocab as a dict"""
- vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
- vocab.update(self.added_tokens_encoder)
- return vocab
-
- def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
- """
- Args:
- text: TextInput
- Simply calls PreTrainedTokenizer's method
- """
- return super().tokenize(text, **kwargs)
-
- def _tokenize(self, text, **kwargs):
- """
- Args:
- text: TextInput
- Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
- """
- return self.sp_model.encode(text, out_type=str)
-
- def _convert_token_to_id(self, token):
- """Converts a token (str) in an id using the vocab."""
- return self.sp_model.piece_to_id(token)
-
- def _convert_id_to_token(self, index):
- """Converts an index (integer) in a token (str) using the vocab."""
- return self.decoder.get(index, "")
-
- def convert_tokens_to_string(self, tokens):
- """Converts a sequence of tokens (string) in a single string."""
- # since we manually add the prefix space, we have to remove it when decoding
- if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
- tokens[0] = tokens[0][1:]
-
- current_sub_tokens = []
- out_string = ""
- prev_is_special = False
- for i, token in enumerate(tokens):
- # make sure that special tokens are not decoded using sentencepiece model
- if token in self.all_special_tokens:
- if not prev_is_special and i != 0 and self.spaces_for_interleaved_special_tokens:
- out_string += " "
- out_string += self.sp_model.decode(current_sub_tokens) + token
- prev_is_special = True
- current_sub_tokens = []
- else:
- if (
- prev_is_special
- and i == 1
- and self.add_prefix_space
- and not token.startswith(SPIECE_UNDERLINE)
- and self.spaces_for_interleaved_special_tokens
- ):
- out_string += " "
- current_sub_tokens.append(token)
- prev_is_special = False
- out_string += self.sp_model.decode(current_sub_tokens)
- return out_string
-
- def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
- """
- Save the vocabulary and special tokens file to a directory.
-
- Args:
- save_directory (`str`):
- The directory in which to save the vocabulary.
-
- Returns:
- `Tuple(str)`: Paths to the files saved.
- """
- if not os.path.isdir(save_directory):
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
- return
- out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
-
- if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
- copyfile(self.vocab_file, out_vocab_file)
- elif not os.path.isfile(self.vocab_file):
- with open(out_vocab_file, "wb") as fi:
- content_spiece_model = self.sp_model.serialized_model_proto()
- fi.write(content_spiece_model)
-
- return (out_vocab_file,)
-
- def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
- bos_token_id = [self.bos_token_id] if self.add_bos_token else []
- eos_token_id = [self.eos_token_id] if self.add_eos_token else []
-
- output = bos_token_id + token_ids_0 + eos_token_id
-
- if token_ids_1 is not None:
- output = output + bos_token_id + token_ids_1 + eos_token_id
-
- return output
-
- def get_special_tokens_mask(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
- ) -> List[int]:
- """
- Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
- special tokens using the tokenizer `prepare_for_model` method.
-
- Args:
- token_ids_0 (`List[int]`):
- List of IDs.
- token_ids_1 (`List[int]`, *optional*):
- Optional second list of IDs for sequence pairs.
- already_has_special_tokens (`bool`, *optional*, defaults to `False`):
- Whether or not the token list is already formatted with special tokens for the model.
-
- Returns:
- `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
- """
- if already_has_special_tokens:
- return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
-
- bos_token_id = [1] if self.add_bos_token else []
- eos_token_id = [1] if self.add_eos_token else []
-
- if token_ids_1 is None:
- return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
- return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id
-
- def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
- """
- Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
- sequence pair mask has the following format:
-
- ```
- 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
- | first sequence | second sequence |
- ```
-
- if token_ids_1 is None, only returns the first portion of the mask (0s).
-
- Args:
- token_ids_0 (`List[int]`):
- List of ids.
- token_ids_1 (`List[int]`, *optional*):
- Optional second list of IDs for sequence pairs.
-
- Returns:
- `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
- """
- bos_token_id = [self.bos_token_id] if self.add_bos_token else []
- eos_token_id = [self.eos_token_id] if self.add_eos_token else []
-
- output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
-
- if token_ids_1 is not None:
- output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
-
- return output
diff --git a/code/xtuner/_lite/modelings/internvl2/__init__.py b/code/xtuner/_lite/modelings/internvl2/__init__.py
deleted file mode 100644
index 8652be2d9b80d45e0b6a2e30ae881ea8e155e678..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internvl2/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .modeling_intern_vit import InternVisionModel
-
-__all__ = ['InternVisionModel']
diff --git a/code/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py b/code/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py
deleted file mode 100644
index 32f469c4bbfee021fe19e622245d16fd9ba0aae6..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internvl2/configuration_intern_vit.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# --------------------------------------------------------
-# InternVL
-# Copyright (c) 2024 OpenGVLab
-# Licensed under The MIT License [see LICENSE for details]
-# --------------------------------------------------------
-import os
-from typing import Union
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class InternVisionConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
- instantiate a vision encoder according to the specified arguments, defining the model architecture.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- num_channels (`int`, *optional*, defaults to 3):
- Number of color channels in the input images (e.g., 3 for RGB).
- patch_size (`int`, *optional*, defaults to 14):
- The size (resolution) of each patch.
- image_size (`int`, *optional*, defaults to 224):
- The size (resolution) of each image.
- qkv_bias (`bool`, *optional*, defaults to `False`):
- Whether to add a bias to the queries and values in the self-attention layers.
- hidden_size (`int`, *optional*, defaults to 3200):
- Dimensionality of the encoder layers and the pooler layer.
- num_attention_heads (`int`, *optional*, defaults to 25):
- Number of attention heads for each attention layer in the Transformer encoder.
- intermediate_size (`int`, *optional*, defaults to 12800):
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
- qk_normalization (`bool`, *optional*, defaults to `True`):
- Whether to normalize the queries and keys in the self-attention layers.
- num_hidden_layers (`int`, *optional*, defaults to 48):
- Number of hidden layers in the Transformer encoder.
- use_flash_attn (`bool`, *optional*, defaults to `True`):
- Whether to use flash attention mechanism.
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
- `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
- layer_norm_eps (`float`, *optional*, defaults to 1e-6):
- The epsilon used by the layer normalization layers.
- dropout (`float`, *optional*, defaults to 0.0):
- The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- drop_path_rate (`float`, *optional*, defaults to 0.0):
- Dropout rate for stochastic depth.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- initializer_factor (`float`, *optional*, defaults to 0.1):
- A factor for layer scale.
- """
-
- model_type = 'intern_vit_6b'
-
- def __init__(
- self,
- num_channels=3,
- patch_size=14,
- image_size=224,
- qkv_bias=False,
- hidden_size=3200,
- num_attention_heads=25,
- intermediate_size=12800,
- qk_normalization=True,
- num_hidden_layers=48,
- use_flash_attn=True,
- hidden_act='gelu',
- norm_type='rms_norm',
- layer_norm_eps=1e-6,
- dropout=0.0,
- drop_path_rate=0.0,
- attention_dropout=0.0,
- initializer_range=0.02,
- initializer_factor=0.1,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.dropout = dropout
- self.drop_path_rate = drop_path_rate
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.num_channels = num_channels
- self.patch_size = patch_size
- self.image_size = image_size
- self.initializer_range = initializer_range
- self.initializer_factor = initializer_factor
- self.attention_dropout = attention_dropout
- self.layer_norm_eps = layer_norm_eps
- self.hidden_act = hidden_act
- self.norm_type = norm_type
- self.qkv_bias = qkv_bias
- self.qk_normalization = qk_normalization
- self.use_flash_attn = use_flash_attn
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
- if 'vision_config' in config_dict:
- config_dict = config_dict['vision_config']
-
- if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
- logger.warning(
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
- f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
- )
-
- return cls.from_dict(config_dict, **kwargs)
\ No newline at end of file
diff --git a/code/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py b/code/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py
deleted file mode 100644
index a8d36d9e3de8c2e8f7afbb37144a13e4b0d8745b..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/internvl2/modeling_intern_vit.py
+++ /dev/null
@@ -1,432 +0,0 @@
-# --------------------------------------------------------
-# InternVL
-# Copyright (c) 2024 OpenGVLab
-# Licensed under The MIT License [see LICENSE for details]
-# --------------------------------------------------------
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from einops import rearrange
-from timm.models.layers import DropPath
-from torch import nn
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (BaseModelOutput,
- BaseModelOutputWithPooling)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
-
-from .configuration_intern_vit import InternVisionConfig
-
-try:
- from flash_attn.bert_padding import pad_input, unpad_input
- from flash_attn.flash_attn_interface import \
- flash_attn_varlen_qkvpacked_func
- has_flash_attn = True
-except:
- print('FlashAttention2 is not installed.')
- has_flash_attn = False
-
-logger = logging.get_logger(__name__)
-
-
-class FlashAttention(nn.Module):
- """Implement the scaled dot product attention with softmax.
- Arguments
- ---------
- softmax_scale: The temperature to use for the softmax attention.
- (default: 1/sqrt(d_keys) where d_keys is computed at
- runtime)
- attention_dropout: The dropout rate to apply to the attention
- (default: 0.0)
- """
-
- def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
- super().__init__()
- self.softmax_scale = softmax_scale
- self.dropout_p = attention_dropout
-
- def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
- max_s=None, need_weights=False):
- """Implements the multihead softmax attention.
- Arguments
- ---------
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
- if unpadded: (nnz, 3, h, d)
- key_padding_mask: a bool tensor of shape (B, S)
- """
- assert not need_weights
- assert qkv.dtype in [torch.float16, torch.bfloat16]
- assert qkv.is_cuda
-
- if cu_seqlens is None:
- batch_size = qkv.shape[0]
- seqlen = qkv.shape[1]
- if key_padding_mask is None:
- qkv = rearrange(qkv, 'b s ... -> (b s) ...')
- max_s = seqlen
- cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
- device=qkv.device)
- output = flash_attn_varlen_qkvpacked_func(
- qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
- softmax_scale=self.softmax_scale, causal=causal
- )
- output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
- else:
- nheads = qkv.shape[-2]
- x = rearrange(qkv, 'b s three h d -> b s (three h d)')
- x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
- x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
- output_unpad = flash_attn_varlen_qkvpacked_func(
- x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
- softmax_scale=self.softmax_scale, causal=causal
- )
- output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
- indices, batch_size, seqlen),
- 'b s (h d) -> b s h d', h=nheads)
- else:
- assert max_s is not None
- output = flash_attn_varlen_qkvpacked_func(
- qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
- softmax_scale=self.softmax_scale, causal=causal
- )
-
- return output, None
-
-
-class InternRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-try:
- from apex.normalization import FusedRMSNorm
-
- InternRMSNorm = FusedRMSNorm # noqa
-
- logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
-except ImportError:
- # using the normal InternRMSNorm
- pass
-except Exception:
- logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
- pass
-
-
-NORM2FN = {
- 'rms_norm': InternRMSNorm,
- 'layer_norm': nn.LayerNorm,
-}
-
-
-class InternVisionEmbeddings(nn.Module):
- def __init__(self, config: InternVisionConfig):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.image_size = config.image_size
- self.patch_size = config.patch_size
-
- self.class_embedding = nn.Parameter(
- torch.randn(1, 1, self.embed_dim),
- )
-
- self.patch_embedding = nn.Conv2d(
- in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
- )
-
- self.num_patches = (self.image_size // self.patch_size) ** 2
- self.num_positions = self.num_patches + 1
-
- self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
-
- def _get_pos_embed(self, pos_embed, H, W):
- target_dtype = pos_embed.dtype
- pos_embed = pos_embed.float().reshape(
- 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
- pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
- reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
- return pos_embed
-
- def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
- target_dtype = self.patch_embedding.weight.dtype
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
- batch_size, _, height, width = patch_embeds.shape
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
- class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
- position_embedding = torch.cat([
- self.position_embedding[:, :1, :],
- self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
- ], dim=1)
- embeddings = embeddings + position_embedding.to(target_dtype)
- return embeddings
-
-
-class InternAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: InternVisionConfig):
- super().__init__()
- self.config = config
- self.embed_dim = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.use_flash_attn = config.use_flash_attn and has_flash_attn
- if config.use_flash_attn and not has_flash_attn:
- print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
- self.head_dim = self.embed_dim // self.num_heads
- if self.head_dim * self.num_heads != self.embed_dim:
- raise ValueError(
- f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
- f' {self.num_heads}).'
- )
-
- self.scale = self.head_dim ** -0.5
- self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
- self.attn_drop = nn.Dropout(config.attention_dropout)
- self.proj_drop = nn.Dropout(config.dropout)
-
- self.qk_normalization = config.qk_normalization
-
- if self.qk_normalization:
- self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
- self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-
- if self.use_flash_attn:
- self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
- self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
- def _naive_attn(self, x):
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
-
- if self.qk_normalization:
- B_, H_, N_, D_ = q.shape
- q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
- k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
-
- attn = ((q * self.scale) @ k.transpose(-2, -1))
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
- qkv = self.qkv(x)
- qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
-
- if self.qk_normalization:
- q, k, v = qkv.unbind(2)
- q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
- k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
- qkv = torch.stack([q, k, v], dim=2)
-
- context, _ = self.inner_attn(
- qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
- )
- outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
- outs = self.proj_drop(outs)
- return outs
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
- return x
-
-
-class InternMLP(nn.Module):
- def __init__(self, config: InternVisionConfig):
- super().__init__()
- self.config = config
- self.act = ACT2FN[config.hidden_act]
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.fc1(hidden_states)
- hidden_states = self.act(hidden_states)
- hidden_states = self.fc2(hidden_states)
- return hidden_states
-
-
-class InternVisionEncoderLayer(nn.Module):
- def __init__(self, config: InternVisionConfig, drop_path_rate: float):
- super().__init__()
- self.embed_dim = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.norm_type = config.norm_type
-
- self.attn = InternAttention(config)
- self.mlp = InternMLP(config)
- self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
- self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
-
- self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
- self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
- self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
- self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
- """
- hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)
-
- hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
-
- return hidden_states
-
-
-class InternVisionEncoder(nn.Module):
- """
- Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
- [`InternEncoderLayer`].
-
- Args:
- config (`InternConfig`):
- The corresponding vision configuration for the `InternEncoder`.
- """
-
- def __init__(self, config: InternVisionConfig):
- super().__init__()
- self.config = config
- # stochastic depth decay rule
- # TODO: error
- # dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
- dpr = np.linspace(0.0, float(config.drop_path_rate), int(config.num_hidden_layers))
- self.layers = nn.ModuleList([
- InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
- self.gradient_checkpointing = True
-
- def forward(
- self,
- inputs_embeds,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutput]:
- r"""
- Args:
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Embedded representation of the inputs. Should be float, not int tokens.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- encoder_states = () if output_hidden_states else None
- hidden_states = inputs_embeds
-
- for idx, encoder_layer in enumerate(self.layers):
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states,)
- if self.gradient_checkpointing and self.training:
- layer_outputs = torch.utils.checkpoint.checkpoint(
- encoder_layer,
- hidden_states)
- else:
- layer_outputs = encoder_layer(
- hidden_states,
- )
- hidden_states = layer_outputs
-
- if output_hidden_states:
- encoder_states = encoder_states + (hidden_states,)
-
- if not return_dict:
- return tuple(v for v in [hidden_states, encoder_states] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=encoder_states
- )
-
-
-class InternVisionModel(PreTrainedModel):
- main_input_name = 'pixel_values'
- _supports_flash_attn_2 = True
- config_class = InternVisionConfig
- _no_split_modules = ['InternVisionEncoderLayer']
-
- def __init__(self, config: InternVisionConfig):
- super().__init__(config)
- self.config = config
-
- self.embeddings = InternVisionEmbeddings(config)
- self.encoder = InternVisionEncoder(config)
-
- def resize_pos_embeddings(self, old_size, new_size, patch_size):
- pos_emb = self.embeddings.position_embedding
- _, num_positions, embed_dim = pos_emb.shape
- cls_emb = pos_emb[:, :1, :]
- pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
- pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
- pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
- pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
- self.embeddings.position_embedding = nn.Parameter(pos_emb)
- self.embeddings.image_size = new_size
- logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
-
- def get_input_embeddings(self):
- return self.embeddings
-
- def forward(
- self,
- pixel_values: Optional[torch.FloatTensor] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- pixel_embeds: Optional[torch.FloatTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if pixel_values is None and pixel_embeds is None:
- raise ValueError('You have to specify pixel_values or pixel_embeds')
-
- if pixel_embeds is not None:
- hidden_states = pixel_embeds
- else:
- if len(pixel_values.shape) == 4:
- hidden_states = self.embeddings(pixel_values)
- else:
- raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
- encoder_outputs = self.encoder(
- inputs_embeds=hidden_states,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- last_hidden_state = encoder_outputs.last_hidden_state
- pooled_output = last_hidden_state[:, 0, :]
-
- if not return_dict:
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
-
- return BaseModelOutputWithPooling(
- last_hidden_state=last_hidden_state,
- pooler_output=pooled_output,
- hidden_states=encoder_outputs.hidden_states,
- attentions=encoder_outputs.attentions,
- )
diff --git a/code/xtuner/_lite/modelings/llava/__init__.py b/code/xtuner/_lite/modelings/llava/__init__.py
deleted file mode 100644
index 036324005b01497a392f8755e6ea582867bf1562..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/llava/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .configuration_llava import EnhancedLlavaConfig
-from .modeling_llava import LlavaForConditionalGeneration
-from .processing_llava import LlavaProcessor
diff --git a/code/xtuner/_lite/modelings/llava/configuration_internlm2.py b/code/xtuner/_lite/modelings/llava/configuration_internlm2.py
deleted file mode 100644
index 8b810794707c07014b2143b5fced9fe0c621c8eb..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/llava/configuration_internlm2.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" InternLM2 model configuration"""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
-
-
-# Modified from transformers.model.llama.configuration_llama.LlamaConfig
-class InternLM2Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
- an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
- configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
- vocab_size (`int`, *optional*, defaults to 32000):
- Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`InternLM2Model`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 11008):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer decoder.
- num_key_value_heads (`int`, *optional*):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
- `num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 2048):
- The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*):
- Padding token id.
- bos_token_id (`int`, *optional*, defaults to 1):
- Beginning of stream token id.
- eos_token_id (`int`, *optional*, defaults to 2):
- End of stream token id.
- pretraining_tp (`int`, *optional*, defaults to 1):
- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
- document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
- to understand more about it. This value is necessary to ensure exact reproducibility
- of the pretraining results. Please refer to [this
- issue](https://github.com/pytorch/pytorch/issues/76232).
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
- strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
- `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
- `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
- these scaling strategies behave:
- https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
- experimental feature, subject to breaking API changes in future versions.
- """
- _auto_class = 'AutoConfig'
- model_type = 'internlm2'
- keys_to_ignore_at_inference = ['past_key_values']
-
- def __init__( # pylint: disable=W0102
- self,
- vocab_size=103168,
- hidden_size=4096,
- intermediate_size=11008,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=None,
- hidden_act='silu',
- max_position_embeddings=2048,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- pretraining_tp=1,
- tie_word_embeddings=False,
- bias=True,
- rope_theta=10000,
- rope_scaling=None,
- attn_implementation=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.bias = bias
-
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
- self.num_key_value_heads = num_key_value_heads
-
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.pretraining_tp = pretraining_tp
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self._rope_scaling_validation()
- self.attn_implementation = attn_implementation
- if self.attn_implementation is None:
- self.attn_implementation = 'eager'
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
- def _rope_scaling_validation(self):
- """
- Validate the `rope_scaling` configuration.
- """
- if self.rope_scaling is None:
- return
-
- if not isinstance(self.rope_scaling,
- dict) or len(self.rope_scaling) != 2:
- raise ValueError(
- '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
- f'got {self.rope_scaling}')
- rope_scaling_type = self.rope_scaling.get('type', None)
- rope_scaling_factor = self.rope_scaling.get('factor', None)
- if rope_scaling_type is None or rope_scaling_type not in [
- 'linear', 'dynamic'
- ]:
- raise ValueError(
- f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
- )
- if (rope_scaling_factor is None
- or not isinstance(rope_scaling_factor,
- (float, int)) or rope_scaling_factor < 1.0):
- raise ValueError(
- f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
- f'of type {type(rope_scaling_factor)}')
diff --git a/code/xtuner/_lite/modelings/llava/configuration_llava.py b/code/xtuner/_lite/modelings/llava/configuration_llava.py
deleted file mode 100644
index f5ec7bbfaec7ea6ea625af5258a1abd62cc9a8d5..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/llava/configuration_llava.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# coding=utf-8
-# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Llava model configuration"""
-import os
-from typing import Union
-from transformers.configuration_utils import PretrainedConfig, custom_object_save
-from transformers.utils import logging
-from transformers import CONFIG_MAPPING, AutoModelForCausalLM, AutoConfig
-
-logger = logging.get_logger(__name__)
-
-class EnhancedLlavaConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
- Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the Llava-9B.
-
- e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
- Args:
- vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
- The config object or dictionary of the vision backbone.
- text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
- The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
- image_token_index (`int`, *optional*, defaults to 32000):
- The image token index to encode the image prompt.
- projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
- The activation function used by the multimodal projector.
- vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
- The feature selection strategy used to select the vision feature from the vision backbone.
- Can be one of `"default"` or `"full"`.
- vision_feature_layer (`int`, *optional*, defaults to -2):
- The index of the layer to select the vision feature.
-
- Example:
-
- ```python
- >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
-
- >>> # Initializing a CLIP-vision config
- >>> vision_config = CLIPVisionConfig()
-
- >>> # Initializing a Llama config
- >>> text_config = LlamaConfig()
-
- >>> # Initializing a Llava llava-1.5-7b style configuration
- >>> configuration = LlavaConfig(vision_config, text_config)
-
- >>> # Initializing a model from the llava-1.5-7b style configuration
- >>> model = LlavaForConditionalGeneration(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- _auto_class = 'AutoConfig'
- model_type = "enhanced_llava"
- is_composition = False
-
- def __init__(
- self,
- vision_config=None,
- text_config=None,
- ignore_index=-100,
- image_token_index=32000,
- projector_hidden_act="gelu",
- vision_feature_select_strategy="default",
- vision_feature_layer=-2,
- **kwargs,
- ):
- self.ignore_index = ignore_index
- self.image_token_index = image_token_index
- self.projector_hidden_act = projector_hidden_act
-
- if vision_feature_select_strategy not in ["default", "full"]:
- raise ValueError(
- "vision_feature_select_strategy should be one of 'default', 'full'."
- f"Got: {vision_feature_select_strategy}"
- )
-
- self.vision_feature_select_strategy = vision_feature_select_strategy
- self.vision_feature_layer = vision_feature_layer
-
- if isinstance(vision_config, dict):
- vision_config["model_type"] = (
- vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
- )
- vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
- elif vision_config is None:
- vision_config = CONFIG_MAPPING["clip_vision_model"](
- intermediate_size=4096,
- hidden_size=1024,
- patch_size=14,
- image_size=336,
- num_hidden_layers=24,
- num_attention_heads=16,
- vocab_size=32000,
- projection_dim=768,
- )
-
- self.vision_config = vision_config
-
- if isinstance(text_config, dict):
- text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
-
- if text_config["model_type"] == 'internlm2':
- from .configuration_internlm2 import InternLM2Config
- from .modeling_internlm2 import InternLM2ForCausalLM
- AutoConfig.register('internlm2', InternLM2Config)
- AutoModelForCausalLM.register(
- InternLM2Config, InternLM2ForCausalLM)
- text_config['auto_map']['AutoConfig'] = 'configuration_internlm2.InternLM2Config'
- text_config['auto_map']['AutoModel'] = 'modeling_internlm2.InternLM2ForCausalLM'
- text_config['auto_map']['AutoModelForCausalLM'] = 'modeling_internlm2.InternLM2ForCausalLM'
- text_config = InternLM2Config(**text_config)
- else:
- text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
-
- elif text_config is None:
- text_config = CONFIG_MAPPING["llama"]()
-
- self.text_config = text_config
-
- super().__init__(**kwargs)
-
-
- def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
- """
- Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
- [`~PretrainedConfig.from_pretrained`] class method.
-
- Args:
- save_directory (`str` or `os.PathLike`):
- Directory where the configuration JSON file will be saved (will be created if it does not exist).
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
- repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
- namespace).
- kwargs (`Dict[str, Any]`, *optional*):
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
- """
- super().save_pretrained(save_directory, push_to_hub, **kwargs)
-
- if self.text_config._auto_class is not None:
- custom_object_save(self.text_config, save_directory, config=self.text_config)
-
-AutoConfig.register('enhanced_llava', EnhancedLlavaConfig, exist_ok=True)
\ No newline at end of file
diff --git a/code/xtuner/_lite/modelings/llava/modeling_internlm2.py b/code/xtuner/_lite/modelings/llava/modeling_internlm2.py
deleted file mode 100644
index 69ddc61969e34c8ebe387e0dc8dfd1d85db576cf..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/llava/modeling_internlm2.py
+++ /dev/null
@@ -1,1899 +0,0 @@
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch InternLM2.5 model."""
-import math
-import queue
-import threading
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from einops import rearrange
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutputWithPast,
- TokenClassifierOutput)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
-from transformers.utils import (add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_greater_or_equal_2_10, logging,
- replace_return_docstrings)
-
-try:
- from transformers.generation.streamers import BaseStreamer
-except Exception:
- BaseStreamer = None
-
-from .configuration_internlm2 import InternLM2Config
-
-try:
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import (index_first_axis, pad_input,
- unpad_input)
-except:
- pass
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = 'InternLM2Config'
-
-
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-class InternLM2RMSNorm(nn.Module):
- """InternLM2RMSNorm is equivalent to T5LayerNorm."""
-
- def __init__(self, hidden_size, eps=1e-6):
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance +
- self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
-
-
-class InternLM2RotaryEmbedding(nn.Module):
- """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
-
- def __init__(self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0):
- super().__init__()
- self.scaling_factor = scaling_factor
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (
- self.base
- **(torch.arange(0, self.dim, 2,
- dtype=torch.int64).float().to(device) / self.dim))
- self.register_buffer('inv_freq', inv_freq, persistent=False)
- # For BC we register cos and sin cached
- self.max_seq_len_cached = max_position_embeddings
-
- @torch.no_grad()
- def forward(self, x, position_ids):
- # x: [bs, num_attention_heads, seq_len, head_size]
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
- position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 since bfloat16 loses precision on long contexts
- # See https://github.com/huggingface/transformers/pull/29285
- device_type = x.device.type
- device_type = device_type if isinstance(
- device_type, str) and device_type != 'mps' else 'cpu'
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float()
- @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
- """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
- def forward(self, x, position_ids):
- # difference to the original RoPE: a scaling factor is aplied to the position ids
- position_ids = position_ids.float() / self.scaling_factor
- cos, sin = super().forward(x, position_ids)
- return cos, sin
-
-
-class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
- """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
- Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
- def forward(self, x, position_ids):
- # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_position_embeddings:
- base = self.base * ((self.scaling_factor * seq_len /
- self.max_position_embeddings) -
- (self.scaling_factor - 1))**(
- self.dim / (self.dim - 2))
- inv_freq = 1.0 / (
- base
- **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(
- x.device) / self.dim))
- self.register_buffer(
- 'inv_freq', inv_freq,
- persistent=False) # TODO joao: this may break with compilation
-
- cos, sin = super().forward(x, position_ids)
- return cos, sin
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
- """Applies Rotary Position Embedding to the query and key tensors.
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-class InternLM2MLP(nn.Module):
- """MLP for InternLM2 model."""
-
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.w1 = nn.Linear(
- self.hidden_size, self.intermediate_size, bias=False)
- self.w3 = nn.Linear(
- self.hidden_size, self.intermediate_size, bias=False)
- self.w2 = nn.Linear(
- self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
-
- return down_proj
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-class InternLM2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self,
- config: InternLM2Config,
- layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
- 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
- 'when creating this class.')
-
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
- f' and `num_heads`: {self.num_heads}).')
-
- self.wqkv = nn.Linear(
- self.hidden_size,
- (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
- bias=config.bias,
- )
- self.wo = nn.Linear(
- self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
-
- self._init_rope()
-
- def _init_rope(self):
- if self.config.rope_scaling is None:
- self.rotary_emb = InternLM2RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.rope_theta,
- )
- else:
- scaling_type = self.config.rope_scaling['type']
- scaling_factor = self.config.rope_scaling['factor']
- if scaling_type == 'linear':
- self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- )
- elif scaling_type == 'dynamic':
- self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- )
- else:
- raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False, # pylint: disable=unused-argument
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- if self.config.pretraining_tp > 1:
- # split qkv_states by tp size
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
- qkv_states = torch.cat(
- [
- F.linear(hidden_states, qkv_slice)
- for qkv_slice in qkv_slices
- ],
- dim=-1 # pylint: disable=E1102
- )
- else:
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states,
- 'b q h gs d -> b q (h gs) d').transpose(1, 2)
- key_states = qkv_states[..., -2, :].transpose(1, 2)
- value_states = qkv_states[..., -1, :].transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_weights = torch.matmul(query_states, key_states.transpose(
- 2, 3)) / math.sqrt(self.head_dim)
-
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(
- attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
- f' {attn_output.size()}')
-
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(
- self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.wo.weight.split(
- self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class InternLM2FlashAttention2(InternLM2Attention):
- """
- InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement,
- # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
- # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
- # produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if isinstance(past_key_value, StaticCache):
- raise ValueError(
- '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` '
- 'make sure to use `sdpa` in the mean time, and open an issue at '
- 'https://github.com/huggingface/transformers')
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout
- # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # dropout_rate = self.attention_dropout if self.training else 0.0
- dropout_rate = 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (InternLM2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.wqkv.weight.dtype
-
- logger.warning_once(
- f'The input hidden states seems to be silently casted in float32, this might be related to'
- f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
- f' {target_dtype}.')
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=dropout_rate)
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.hidden_size).contiguous()
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value # pylint: disable=E0606
-
- def _flash_attention_forward(self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length,
- dropout=0.0,
- softmax_scale=None):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`float`):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- """
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
- # For details, please see the comment in InternLM2FlashAttention2 __init__.
- causal = self.is_causal and query_length != 1
-
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, attention_mask,
- query_length)
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
- query_length) # pylint: disable=E0606
- else:
- attn_output = flash_attn_func( # pylint: disable=E0606
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal)
-
- return attn_output
-
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
- query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
- attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
- key_layer = index_first_axis( # pylint: disable=E0606
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim), indices_k)
- value_layer = index_first_axis( # pylint: disable=E0606
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim), indices_k)
- if query_length == kv_seq_len:
- query_layer = index_first_axis( # pylint: disable=E0606
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
- head_dim), indices_k)
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
- query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LllamaSdpaAttention with Llama->InternLM2
-class InternLM2SdpaAttention(InternLM2Attention):
- """
- InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
- `InternLM2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
- to adapt to SDPA API.
- """
-
- # Adapted from InternLM2Attention.forward
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
- # once this is implemented.
- logger.warning_once(
- 'InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` '
- 'does not support `output_attentions=True`. Falling back to the manual attention implementation, '
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. '
- 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- causal_mask = attention_mask
- if attention_mask is not None:
- causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
- # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == 'cuda' and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
- # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
- # options. An inline conditional prevents dynamic shapes from compiling.
- is_causal = bool(causal_mask is None and q_len > 1)
-
- attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=0.0,
- is_causal=is_causal,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-
- attn_output = self.wo(attn_output)
-
- return attn_output, None, past_key_value
-
-
-INTERNLM2_ATTENTION_CLASSES = {
- 'eager': InternLM2Attention,
- 'flash_attention_2': InternLM2FlashAttention2,
- 'sdpa': InternLM2SdpaAttention,
-}
-
-
-# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
-class InternLM2DecoderLayer(nn.Module):
- """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
-
- def __init__(self, config: InternLM2Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.layer_idx = layer_idx
-
- self.attention = INTERNLM2_ATTENTION_CLASSES[
- config.attn_implementation](
- config=config, layer_idx=layer_idx)
-
- self.feed_forward = InternLM2MLP(config)
- self.attention_norm = InternLM2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
- self.ffn_norm = InternLM2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
- torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- residual = hidden_states
-
- hidden_states = self.attention_norm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.attention(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.ffn_norm(hidden_states)
- hidden_states = self.feed_forward(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states, )
-
- if output_attentions:
- outputs += (self_attn_weights, )
-
- if use_cache:
- outputs += (present_key_value, )
-
- return outputs
-
-
-InternLM2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
- Parameters:
- config ([`InternLM2Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
-@add_start_docstrings(
- 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
- InternLM2_START_DOCSTRING,
-)
-class InternLM2PreTrainedModel(PreTrainedModel):
- """
- InternLM2 pretraiend model's base class.
- """
-
- config_class = InternLM2Config
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
- _no_split_modules = ['InternLM2DecoderLayer']
- _skip_keys_device_placement = ['past_key_values']
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-InternLM2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance;
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
-"""
-
-
-# Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
-@add_start_docstrings(
- 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
- InternLM2_START_DOCSTRING,
-)
-class InternLM2Model(InternLM2PreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
- Args:
- config: InternLM2Config
- """
-
- _auto_class = 'AutoModel'
-
- def __init__(self, config: InternLM2Config):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.config = config
-
- self.tok_embeddings = nn.Embedding(config.vocab_size,
- config.hidden_size,
- self.padding_idx)
-
- self.layers = nn.ModuleList([
- InternLM2DecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
- ])
- self.norm = InternLM2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError(
- 'You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one'
- )
-
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.'
- )
- use_cache = False
-
- if inputs_embeds is None:
- inputs_embeds = self.tok_embeddings(input_ids)
-
- return_legacy_cache = False
- if use_cache and not isinstance(
- past_key_values,
- Cache): # kept for BC (non `Cache` `past_key_values` inputs)
- return_legacy_cache = True
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length(
- ) if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens,
- past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device)
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
- cache_position, past_key_values,
- output_attentions)
-
- # embed positions
- hidden_states = inputs_embeds
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[
- 2 if output_attentions else 1]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1], )
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- next_cache = next_decoder_cache if use_cache else None
- if return_legacy_cache:
- next_cache = next_cache.to_legacy_cache()
-
- if not return_dict:
- return tuple(
- v for v in
- [hidden_states, next_cache, all_hidden_states, all_self_attns]
- if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool,
- ):
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
- # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
- # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
- # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
- # See more context in https://github.com/huggingface/transformers/pull/29114
-
- if self.config.attn_implementation == 'flash_attention_2':
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length(
- ) if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config.attn_implementation == 'sdpa' and not using_static_cache and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
-
- dtype, device = input_tensor.dtype, input_tensor.device
- min_dtype = torch.finfo(dtype).min
- sequence_length = input_tensor.shape[1]
- if using_static_cache:
- target_length = past_key_values.get_max_length()
- else:
- target_length = (
- attention_mask.shape[-1] if isinstance(
- attention_mask, torch.Tensor) else past_seen_tokens +
- sequence_length + 1)
-
- if attention_mask is not None and attention_mask.dim() == 4:
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
- if attention_mask.max() != 0:
- raise ValueError(
- 'Custom 4D attention mask should be passed in inverted form with max==0`'
- )
- causal_mask = attention_mask
- else:
- causal_mask = torch.full((sequence_length, target_length),
- fill_value=min_dtype,
- dtype=dtype,
- device=device)
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(
- target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(
- input_tensor.shape[0], 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone(
- ) # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :
- mask_length] + attention_mask[:,
- None,
- None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :
- mask_length] = causal_mask[:, :, :, :
- mask_length].masked_fill(
- padding_mask,
- min_dtype)
- if (self.config.attn_implementation == 'sdpa'
- and attention_mask is not None
- and attention_mask.device.type == 'cuda'
- and not output_attentions):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- causal_mask = AttentionMaskConverter._unmask_unattended(
- causal_mask, min_dtype) # pylint: disable=E1120
-
- return causal_mask
-
-
-# Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
-class InternLM2ForCausalLM(InternLM2PreTrainedModel):
- """Causal language model (CLM) for InternLM2."""
-
- _auto_class = 'AutoModelForCausalLM'
- _tied_weights_keys = ['output.weight']
-
- def __init__(self, config):
- super().__init__(config)
- self.model = InternLM2Model(config)
- self.vocab_size = config.vocab_size
- self.output = nn.Linear(
- config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- def get_output_embeddings(self):
- return self.output
-
- def set_output_embeddings(self, new_embeddings):
- self.output = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- @replace_return_docstrings(
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- Returns:
- Example:
- ```python
- >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
- >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- )
-
- hidden_states = outputs[0]
- if self.config.pretraining_tp > 1:
- output_slices = self.output.weight.split(
- self.vocab_size // self.config.pretraining_tp, dim=0)
- logits = [
- F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable
- for i in range(self.config.pretraining_tp)
- ]
- logits = torch.cat(logits, dim=-1)
- else:
- logits = self.output(hidden_states)
- logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits, ) + outputs[1:]
- return (loss, ) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- cache_position=None,
- use_cache=True,
- **kwargs,
- ):
- past_length = 0
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- past_length = cache_position[
- 0] if cache_position is not None else past_key_values.get_seq_length(
- )
- max_cache_length = (
- torch.tensor(
- past_key_values.get_max_length(),
- device=input_ids.device)
- if past_key_values.get_max_length() is not None else None)
- cache_length = past_length if max_cache_length is None else torch.min(
- max_cache_length, past_length)
- # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
- max_cache_length = None
-
- # Keep only the unprocessed tokens:
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
- if attention_mask is not None and attention_mask.shape[
- 1] > input_ids.shape[1]:
- input_ids = input_ids[:, -(attention_mask.shape[1] -
- past_length):]
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
- # input_ids based on the past_length.
- elif past_length < input_ids.shape[1]:
- input_ids = input_ids[:, past_length:]
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
- if (max_cache_length is not None and attention_mask is not None
- and cache_length + input_ids.shape[1] > max_cache_length):
- attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
-
- position_ids = kwargs.get('position_ids', None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1]:]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {'inputs_embeds': inputs_embeds}
- else:
- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
- # recompiles graphs as the stride of the inputs is a guard.
- # Ref: https://github.com/huggingface/transformers/pull/29114
- # TODO: use `next_tokens` directly instead.
- model_inputs = {'input_ids': input_ids.contiguous()}
-
- input_length = position_ids.shape[
- -1] if position_ids is not None else input_ids.shape[-1]
- if cache_position is None:
- cache_position = torch.arange(
- past_length,
- past_length + input_length,
- device=input_ids.device)
- elif use_cache:
- cache_position = cache_position[-input_length:]
-
- model_inputs.update({
- 'position_ids': position_ids,
- 'cache_position': cache_position,
- 'past_key_values': past_key_values,
- 'use_cache': use_cache,
- 'attention_mask': attention_mask,
- })
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (tuple(
- past_state.index_select(0, beam_idx.to(past_state.device))
- for past_state in layer_past), )
- return reordered_past
-
- def build_inputs(self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = None,
- meta_instruction=''):
- if history is None:
- history = []
- if tokenizer.add_bos_token:
- prompt = ''
- else:
- prompt = tokenizer.bos_token
- if meta_instruction:
- prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
- for record in history:
- prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
- prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
- return tokenizer([prompt], return_tensors='pt')
-
- @torch.no_grad()
- def chat(
- self,
- tokenizer,
- query: str,
- history: Optional[List[Tuple[str, str]]] = None,
- streamer: Optional[BaseStreamer] = None,
- max_new_tokens: int = 1024,
- do_sample: bool = True,
- temperature: float = 0.8,
- top_p: float = 0.8,
- meta_instruction:
- str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
- '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory '
- '(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
- '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such '
- 'as English and 中文.',
- **kwargs,
- ):
- if history is None:
- history = []
- inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
- inputs = {
- k: v.to(self.device)
- for k, v in inputs.items() if torch.is_tensor(v)
- }
- # also add end-of-assistant token in eos token id to avoid unnecessary generation
- eos_token_id = [
- tokenizer.eos_token_id,
- tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]
- ]
- outputs = self.generate(
- **inputs,
- streamer=streamer,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
- top_p=top_p,
- eos_token_id=eos_token_id,
- **kwargs,
- )
- outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
- response = tokenizer.decode(outputs, skip_special_tokens=True)
- response = response.split('<|im_end|>')[0]
- history = history + [(query, response)]
- return response, history
-
- @torch.no_grad()
- def stream_chat(
- self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = None,
- max_new_tokens: int = 1024,
- do_sample: bool = True,
- temperature: float = 0.8,
- top_p: float = 0.8,
- **kwargs,
- ):
- if history is None:
- history = []
- """
- Return a generator in format: (response, history)
- Eg.
- ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
- ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
- """
- if BaseStreamer is None:
- raise ModuleNotFoundError(
- 'The version of `transformers` is too low. Please make sure '
- 'that you have installed `transformers>=4.28.0`.')
-
- response_queue = queue.Queue(maxsize=20)
-
- class ChatStreamer(BaseStreamer):
- """
- Streamer used in generate to print words one by one.
- """
-
- def __init__(self, tokenizer) -> None:
- super().__init__()
- self.tokenizer = tokenizer
- self.queue = response_queue
- self.query = query
- self.history = history
- self.response = ''
- self.cache = []
- self.received_inputs = False
- self.queue.put(
- (self.response, history + [(self.query, self.response)]))
-
- def put(self, value):
- if len(value.shape) > 1 and value.shape[0] > 1:
- raise ValueError('ChatStreamer only supports batch size 1')
- elif len(value.shape) > 1:
- value = value[0]
-
- if not self.received_inputs:
- # The first received value is input_ids, ignore here
- self.received_inputs = True
- return
-
- self.cache.extend(value.tolist())
- token = self.tokenizer.decode(
- self.cache, skip_special_tokens=True)
- if token.strip() != '<|im_end|>':
- self.response = self.response + token
- history = self.history + [(self.query, self.response)]
- self.queue.put((self.response, history))
- self.cache = []
- else:
- self.end()
-
- def end(self):
- self.queue.put(None)
-
- def stream_producer():
- return self.chat(
- tokenizer=tokenizer,
- query=query,
- streamer=ChatStreamer(tokenizer=tokenizer),
- history=history,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
- top_p=top_p,
- **kwargs,
- )
-
- def consumer():
- producer = threading.Thread(target=stream_producer)
- producer.start()
- while True:
- res = response_queue.get()
- if res is None:
- return
- yield res
-
- return consumer()
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
-@add_start_docstrings(
- """
- The InternLM2 Model transformer with a sequence classification head on top (linear layer).
- [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
- (e.g. GPT-2) do.
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
- """Sequence Classification Head for InternLM2 Model."""
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = InternLM2Model(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError(
- 'Cannot handle batch sizes > 1 if no padding token is defined.'
- )
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
- sequence_lengths = torch.eq(
- input_ids, self.config.pad_token_id).int().argmax(-1) - 1
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
- sequence_lengths = sequence_lengths.to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device),
- sequence_lengths]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = 'regression'
- elif self.num_labels > 1 and (labels.dtype
- in (torch.long, torch.int)):
- self.config.problem_type = 'single_label_classification'
- else:
- self.config.problem_type = 'multi_label_classification'
-
- if self.config.problem_type == 'regression':
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == 'single_label_classification':
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(
- pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == 'multi_label_classification':
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits, ) + transformer_outputs[1:]
- return ((loss, ) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
-@add_start_docstrings(
- """
-The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
-SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
- """Question Answering model for InternLM2."""
-
- base_model_prefix = 'transformer'
-
- def __init__(self, config):
- super().__init__(config)
- self.transformer = InternLM2Model(config)
- self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.transformer.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.transformer.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- start_positions: Optional[torch.LongTensor] = None,
- end_positions: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, QuestionAnsweringModelOutput]:
- r"""
- start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
- are not taken into account for computing the loss.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.transformer(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- sequence_output = outputs[0]
-
- logits = self.qa_outputs(sequence_output)
- start_logits, end_logits = logits.split(1, dim=-1)
- start_logits = start_logits.squeeze(-1).contiguous()
- end_logits = end_logits.squeeze(-1).contiguous()
-
- total_loss = None
- if start_positions is not None and end_positions is not None:
- # If we are on multi-GPU, split add a dimension
- if len(start_positions.size()) > 1:
- start_positions = start_positions.squeeze(-1).to(
- start_logits.device)
- if len(end_positions.size()) > 1:
- end_positions = end_positions.squeeze(-1).to(end_logits.device)
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
- ignored_index = start_logits.size(1)
- start_positions = start_positions.clamp(0, ignored_index)
- end_positions = end_positions.clamp(0, ignored_index)
-
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
- start_loss = loss_fct(start_logits, start_positions)
- end_loss = loss_fct(end_logits, end_positions)
- total_loss = (start_loss + end_loss) / 2
-
- if not return_dict:
- output = (start_logits, end_logits) + outputs[2:]
- return ((total_loss, ) +
- output) if total_loss is not None else output
-
- return QuestionAnsweringModelOutput(
- loss=total_loss,
- start_logits=start_logits,
- end_logits=end_logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
-@add_start_docstrings(
- """
- The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
- output) e.g. for Named-Entity-Recognition (NER) tasks.
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
- """Token classification model for InternLM2."""
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = InternLM2Model(config)
- if getattr(config, 'classifier_dropout', None) is not None:
- classifier_dropout = config.classifier_dropout
- elif getattr(config, 'hidden_dropout', None) is not None:
- classifier_dropout = config.hidden_dropout
- else:
- classifier_dropout = 0.1
- self.dropout = nn.Dropout(classifier_dropout)
- self.score = nn.Linear(config.hidden_size, config.num_labels)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = outputs[0]
- sequence_output = self.dropout(sequence_output)
- logits = self.score(sequence_output)
-
- loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
- if not return_dict:
- output = (logits, ) + outputs[2:]
- return ((loss, ) + output) if loss is not None else output
-
- return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
diff --git a/code/xtuner/_lite/modelings/llava/modeling_llava.py b/code/xtuner/_lite/modelings/llava/modeling_llava.py
deleted file mode 100644
index b987db7b5382d248a8000cea2fc87176adcee44f..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/llava/modeling_llava.py
+++ /dev/null
@@ -1,573 +0,0 @@
-# coding=utf-8
-# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch Llava model."""
-
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from transformers import PreTrainedModel
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache
-from transformers.modeling_outputs import ModelOutput
-from transformers.utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
-)
-from transformers import AutoModel, AutoModelForCausalLM
-from .configuration_llava import EnhancedLlavaConfig
-
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "LlavaConfig"
-
-
-
-@dataclass
-# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
-class LlavaCausalLMOutputWithPast(ModelOutput):
- """
- Base class for Llava causal language model (or autoregressive) outputs.
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
- Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
- sequence_length, hidden_size)`.
-
- image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- past_key_values: Optional[List[torch.FloatTensor]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-
-
-class LlavaMultiModalProjector(nn.Module):
- def __init__(self, config: EnhancedLlavaConfig):
- super().__init__()
-
- self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
- self.act = ACT2FN[config.projector_hidden_act]
- self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
-
- def forward(self, image_features):
- hidden_states = self.linear_1(image_features)
- hidden_states = self.act(hidden_states)
- hidden_states = self.linear_2(hidden_states)
- return hidden_states
-
-
-LLAVA_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
- LLAVA_START_DOCSTRING,
-)
-class LlavaPreTrainedModel(PreTrainedModel):
- config_class = EnhancedLlavaConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["LlavaVisionAttention"]
- _skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
-
- def _init_weights(self, module):
- # important: this ported version of Llava isn't meant for training from scratch - only
- # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
- # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
- std = (
- self.config.initializer_range
- if hasattr(self.config, "initializer_range")
- else self.config.text_config.initializer_range
- )
-
- if hasattr(module, "class_embedding"):
- module.class_embedding.data.normal_(mean=0.0, std=std)
-
- if isinstance(module, (nn.Linear, nn.Conv2d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- @property
- def _supports_sdpa(self):
- """
- Retrieve language_model's attribute to check whether the model supports
- SDPA or not.
- """
- return self.language_model._supports_sdpa
-
-
-LLAVA_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
- The tensors corresponding to the input images. Pixel values can be obtained using
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
- [`CLIPImageProcessor`] for processing images).
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- vision_feature_layer (`int`, *optional*, defaults to -2):
- The index of the layer to select the vision feature.
- vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
- The feature selection strategy used to select the vision feature from the vision backbone.
- Can be one of `"default"` or `"full"`.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
- """The LLAVA model which consists of a vision backbone and a language model.""",
- LLAVA_START_DOCSTRING,
-)
-class LlavaForConditionalGeneration(LlavaPreTrainedModel):
-
- _auto_class = 'AutoModel'
-
- def __init__(self, config: EnhancedLlavaConfig):
- super().__init__(config)
- self.vision_tower = AutoModel.from_config(config.vision_config)
-
- self.multi_modal_projector = LlavaMultiModalProjector(config)
- self.vocab_size = config.text_config.vocab_size
- self.language_model = AutoModelForCausalLM.from_config(
- config.text_config,
- attn_implementation=config._attn_implementation)
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
- self.post_init()
-
- def get_input_embeddings(self):
- return self.language_model.get_input_embeddings()
-
- def set_input_embeddings(self, value):
- self.language_model.set_input_embeddings(value)
-
- def get_output_embeddings(self):
- return self.language_model.get_output_embeddings()
-
- def set_output_embeddings(self, new_embeddings):
- self.language_model.set_output_embeddings(new_embeddings)
-
- def set_decoder(self, decoder):
- self.language_model.set_decoder(decoder)
-
- def get_decoder(self):
- return self.language_model.get_decoder()
-
- def tie_weights(self):
- return self.language_model.tie_weights()
-
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
- # update vocab size
- self.config.text_config.vocab_size = model_embeds.num_embeddings
- self.vocab_size = model_embeds.num_embeddings
- return model_embeds
-
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
- num_images, num_image_patches, embed_dim = image_features.shape
- batch_size, sequence_length = input_ids.shape
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == self.config.image_token_index
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # Compute the maximum embed dimension
- max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
- batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
- nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
- if left_padding:
- new_token_positions += nb_image_pad[:, None] # offset for left padding
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- if labels is not None:
- final_labels = torch.full(
- (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- if labels is not None:
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
- image_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
- )
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
-
- if image_to_overwrite.sum() != image_features.shape[:-1].numel():
- raise ValueError(
- f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
- )
-
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
- final_embedding[batch_indices, indices_to_mask] = 0
-
- if labels is None:
- final_labels = None
-
- return final_embedding, final_attention_mask, final_labels, position_ids
-
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- pixel_values: torch.FloatTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- vision_feature_layer: Optional[int] = None,
- vision_feature_select_strategy: Optional[str] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- vision_feature_layer = (
- vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
- )
- vision_feature_select_strategy = (
- vision_feature_select_strategy
- if vision_feature_select_strategy is not None
- else self.config.vision_feature_select_strategy
- )
-
- if inputs_embeds is None:
- # 1. Extra the input embeddings
- inputs_embeds = self.get_input_embeddings()(input_ids)
-
- # ------------- start add this ----------------
- if pixel_values is None and self.training:
- # all of the input is text
- # If not handled properly, deadlock can occur.
- # print('===================all of the input is text==============')
- image_size = self.config.vision_config.image_size
- pixel_values = torch.zeros(input_ids.shape[0], 3, image_size, image_size,
- dtype=torch.float32,
- device=input_ids.device)
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
- # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
- if vision_feature_select_strategy == "default":
- selected_image_feature = selected_image_feature[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_image_feature = selected_image_feature
- else:
- raise ValueError(
- f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
- )
- image_features = self.multi_modal_projector(selected_image_feature)
- inputs_embeds = inputs_embeds.to(image_features.dtype)
- inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
- image_features[0:0], inputs_embeds, input_ids, attention_mask, labels
- )
- # ------------- end add this ----------------
- # 2. Merge text and images
- elif pixel_values is not None and input_ids.shape[1] != 1:
- image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
- # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
- selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-
- if vision_feature_select_strategy == "default":
- selected_image_feature = selected_image_feature[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_image_feature = selected_image_feature
- else:
- raise ValueError(
- f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
- )
-
- image_features = self.multi_modal_projector(selected_image_feature)
- inputs_embeds = inputs_embeds.to(image_features.dtype)
- inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
- image_features, inputs_embeds, input_ids, attention_mask, labels
- )
-
- # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
- # generation with cache
- elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
- # Retrieve the first layer to inspect the logits and mask out the hidden states
- # that are set to 0
- first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
- # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
- batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
- # Get the target length
- target_length = input_ids.shape[1]
- past_length = first_layer_past_key_value.shape[-1]
-
- extended_attention_mask = torch.ones(
- (attention_mask.shape[0], past_length),
- dtype=attention_mask.dtype,
- device=attention_mask.device,
- )
-
- # Filter out only the tokens that can be un-attended, this can happen
- # if one uses Llava + Fused modules where the cache on the
- # first iteration is already big enough, or if one passes custom cache
- valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
- new_batch_index = batch_index[valid_indices]
- new_non_attended_tokens = non_attended_tokens[valid_indices]
-
- # Zero-out the places where we don't need to attend
- extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
- attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-
- outputs = self.language_model(
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- logits = outputs[0]
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- shift_attention_mask = attention_mask[..., 1:]
- shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return LlavaCausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
- ):
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- cache_length = past_key_values.get_seq_length()
- past_length = past_key_values.seen_tokens
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
-
- # Keep only the unprocessed tokens:
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
- # input)
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
- # input_ids based on the past_length.
- elif past_length < input_ids.shape[1]:
- input_ids = input_ids[:, past_length:]
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
- elif self.config.image_token_index in input_ids:
- input_ids = input_ids[:, input_ids.shape[1] - 1 :]
- # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
- # older attention values, as their corresponding values are not part of the input.
- if cache_length < past_length and attention_mask is not None:
- attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
-
- position_ids = kwargs.get("position_ids", None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids}
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "attention_mask": attention_mask,
- "pixel_values": pixel_values,
- }
- )
- return model_inputs
-
- def _reorder_cache(self, *args, **kwargs):
- return self.language_model._reorder_cache(*args, **kwargs)
-
-AutoModel.register(EnhancedLlavaConfig, LlavaForConditionalGeneration, exist_ok=True)
-AutoModelForCausalLM.register(EnhancedLlavaConfig, LlavaForConditionalGeneration, exist_ok=True)
\ No newline at end of file
diff --git a/code/xtuner/_lite/modelings/llava/processing_llava.py b/code/xtuner/_lite/modelings/llava/processing_llava.py
deleted file mode 100644
index 2309755757968aa2cb1ee31fc815343a0f36ee3e..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/modelings/llava/processing_llava.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Processor class for Llava.
-"""
-
-from typing import List, Optional, Union
-
-from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import ImageInput
-from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from transformers.utils import TensorType
-
-
-class LlavaProcessor(ProcessorMixin):
- r"""
- Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
-
- [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
- [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
-
- Args:
- image_processor ([`CLIPImageProcessor`], *optional*):
- The image processor is a required input.
- tokenizer ([`LlamaTokenizerFast`], *optional*):
- The tokenizer is a required input.
- chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
- in a chat into a tokenizable string.
- """
-
- attributes = ["image_processor", "tokenizer"]
- valid_kwargs = ["chat_template"]
- image_processor_class = "AutoImageProcessor"
- tokenizer_class = "AutoTokenizer"
-
- def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
- super().__init__(image_processor, tokenizer, chat_template=chat_template)
-
- def __call__(
- self,
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
- images: ImageInput = None,
- padding: Union[bool, str, PaddingStrategy] = False,
- truncation: Union[bool, str, TruncationStrategy] = None,
- max_length=None,
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
- ) -> BatchFeature:
- """
- Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
- and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
- the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
- CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
- of the above two methods for more information.
-
- Args:
- text (`str`, `List[str]`, `List[List[str]]`):
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
- tensor. Both channels-first and channels-last formats are supported.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`, *optional*):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors of a particular framework. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return NumPy `np.ndarray` objects.
- - `'jax'`: Return JAX `jnp.ndarray` objects.
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
- `None`).
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- """
- if images is not None:
- image_inputs = self.image_processor(images, return_tensors=return_tensors)
- else:
- image_inputs = {}
- text_inputs = self.tokenizer(
- text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
- )
-
- return BatchFeature(data={**text_inputs, **image_inputs})
-
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
- def batch_decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
- refer to the docstring of this method for more information.
- """
- return self.tokenizer.batch_decode(*args, **kwargs)
-
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
- def decode(self, *args, **kwargs):
- """
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
- the docstring of this method for more information.
- """
- return self.tokenizer.decode(*args, **kwargs)
-
- @property
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
- def model_input_names(self):
- tokenizer_input_names = self.tokenizer.model_input_names
- image_processor_input_names = self.image_processor.model_input_names
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
\ No newline at end of file
diff --git a/code/xtuner/_lite/parallel/__init__.py b/code/xtuner/_lite/parallel/__init__.py
deleted file mode 100644
index af975873b3c5a0209d427f1f33a8098183d2da65..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .comm import all_to_all, all_to_all_list, barrier
-from .sampler import LengthGroupedSampler, ParallelSampler, VLMLengthGroupedSampler
-from .sequence import * # noqa: F401, F403
-from .setup import setup_parallel
-
-__all__ = [
- "ParallelSampler",
- "LengthGroupedSampler",
- "VLMLengthGroupedSampler",
- "all_to_all",
- "all_to_all_list",
- "setup_parallel",
- "barrier",
-]
diff --git a/code/xtuner/_lite/parallel/comm.py b/code/xtuner/_lite/parallel/comm.py
deleted file mode 100644
index 29bd6ab533c701cdee8c02d077b0232199a4fefd..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/comm.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Tuple
-
-import torch
-import torch.distributed as dist
-from torch import Tensor
-from torch.distributed.distributed_c10d import (
- _get_pg_default_device,
- _object_to_tensor,
- _tensor_to_object,
-)
-
-
-# Modified from https://github.com/microsoft/DeepSpeed/blob/ffd0a0e3ef24bfd00c2e5f35019d2674cc01ec14/deepspeed/sequence/layer.py#L15 # noqa: E501
-def _all_to_all(
- input: Tensor,
- world_size: int,
- group: dist.ProcessGroup,
- scatter_dim: int,
- gather_dim: int,
-):
- input_list = [
- t.contiguous() for t in torch.tensor_split(input, world_size, scatter_dim)
- ]
- output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
- dist.all_to_all(output_list, input_list, group=group)
- return torch.cat(output_list, dim=gather_dim).contiguous()
-
-
-class _AllToAll(torch.autograd.Function):
- """All-to-all communication.
-
- Args:
- input: Input tensor
- sp_group: Sequence parallel process group
- scatter_dim: Scatter dimension
- gather_dim: Gather dimension
- """
-
- @staticmethod
- def forward(
- ctx: Any,
- input: Tensor,
- sp_group: dist.ProcessGroup,
- scatter_dim: int,
- gather_dim: int,
- ):
- ctx.sp_group = sp_group
- ctx.scatter_dim = scatter_dim
- ctx.gather_dim = gather_dim
- ctx.world_size = dist.get_world_size(sp_group)
- output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim, gather_dim)
- return output
-
- @staticmethod
- def backward(ctx: Any, grad_output: Tensor) -> Tuple:
- grad_output = _all_to_all(
- grad_output,
- ctx.world_size,
- ctx.sp_group,
- ctx.gather_dim,
- ctx.scatter_dim,
- )
- return (
- grad_output,
- None,
- None,
- None,
- )
-
-
-def all_to_all(
- input: Tensor,
- sp_group: dist.ProcessGroup,
- scatter_dim: int = 2,
- gather_dim: int = 1,
-):
- """Convenience function to apply the all-to-all operation with scatter and
- gather dimensions.
-
- Notes:
- We have wrapped the `torch.distributed.all_to_all` function to
- enable automatic differentiation of the all-to-all operation.
-
- Args:
- input: The input tensor for which all-to-all communication is performed
- sp_group: The sequence parallel process group.
- scatter_dim: The dimension along which the input tensor is scattered
- (default: 2).
- gather_dim: The dimension along which the output tensor is gathered
- (default: 1).
-
- Returns:
- The output tensor after the all-to-all communication.
- """
- return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim)
-
-
-def all_to_all_list(object_list, group=None):
- current_device = _get_pg_default_device(group)
- rank = dist.get_rank(group)
- world_size = dist.get_world_size(group)
- tensor_list, size_list = zip(
- *[_object_to_tensor(obj, current_device, group) for obj in object_list]
- )
- tensor_list = list(tensor_list)
- size_list = torch.cat(size_list)
- buffer = [None] * world_size
-
- dist.all_gather_object(buffer, size_list, group=group)
- size_this_rank = []
- for size_list in buffer:
- size_this_rank.append(size_list[rank])
-
- target_tensor_list = [
- torch.empty(size.item(), dtype=torch.uint8, device=current_device)
- for size in size_this_rank
- ]
- dist.all_to_all(target_tensor_list, tensor_list, group=group)
-
- for i in range(len(target_tensor_list)):
- obj_view = target_tensor_list[i].type(torch.uint8)
- target_tensor_list[i] = _tensor_to_object(obj_view, size_this_rank[i], group)
-
- return target_tensor_list
-
-
-def barrier():
- if not dist.is_available():
- return
-
- rank = dist.get_rank()
- if rank == 0:
- objects = [1]
- else:
- objects = [None]
-
- dist.broadcast_object_list(objects, src=0)
- return
diff --git a/code/xtuner/_lite/parallel/sampler.py b/code/xtuner/_lite/parallel/sampler.py
deleted file mode 100644
index 00464ad65d63c5800bc1f5ec764947b6f749029e..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/sampler.py
+++ /dev/null
@@ -1,397 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import random
-from typing import Iterator, Optional, Sized
-
-import torch
-from mmengine.dist import sync_random_seed
-from torch.distributed.device_mesh import DeviceMesh
-from torch.utils.data import ConcatDataset as TorchConcatDataset
-from torch.utils.data import Sampler
-
-
-class ParallelSampler(Sampler):
- """The default data sampler for both distributed and non-distributed
- environment.
-
- It has several differences from the PyTorch ``DistributedSampler`` as
- below:
-
- 1. This sampler supports non-distributed environment.
-
- 2. The round up behaviors are a little different.
-
- - If ``round_up=True``, this sampler will add extra samples to make the
- number of samples is evenly divisible by the world size. And
- this behavior is the same as the ``DistributedSampler`` with
- ``drop_last=False``.
- - If ``round_up=False``, this sampler won't remove or add any samples
- while the ``DistributedSampler`` with ``drop_last=True`` will remove
- tail samples.
-
- Args:
- dataset (Sized): The dataset.
- shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
- seed (int, optional): Random seed used to shuffle the sampler if
- :attr:`shuffle=True`. This number should be identical across all
- processes in the distributed group. Defaults to None.
- round_up (bool): Whether to add extra samples to make the number of
- samples evenly divisible by the world size. Defaults to True.
- """
-
- def __init__(
- self,
- dataset: Sized,
- dp_mesh: DeviceMesh,
- global_batch_size: int,
- shuffle: bool = True,
- seed: Optional[int] = None,
- round_up: bool = True,
- ) -> None:
- rank = dp_mesh.get_local_rank()
- world_size = dp_mesh.size()
-
- assert global_batch_size % world_size == 0
- self.global_batch_size = global_batch_size
- self.rank = rank
- self.world_size = world_size
-
- self.dataset = dataset
- self.shuffle = shuffle
- if seed is None:
- seed = sync_random_seed()
- self.seed = seed
- self.epoch = 0
- self.step = 0
- self.round_up = round_up
-
- if self.round_up:
- self.num_samples = (
- math.ceil(len(self.dataset) / global_batch_size)
- * global_batch_size
- // world_size
- )
- self.total_size = self.num_samples * self.world_size
- else:
- self.num_samples = math.ceil((len(self.dataset) - rank) / world_size)
- self.total_size = len(self.dataset)
-
- def __iter__(self) -> Iterator[int]:
- """Iterate the indices."""
- # deterministically shuffle based on epoch and seed
- if self.shuffle:
- g = torch.Generator()
- g.manual_seed(self.seed + self.epoch)
- indices = torch.randperm(len(self.dataset), generator=g).tolist()
- else:
- indices = torch.arange(len(self.dataset)).tolist()
-
- # add extra samples to make it evenly divisible
- if self.round_up:
- indices = (indices * int(self.total_size / len(indices) + 1))[
- : self.total_size
- ]
-
- # subsample
- indices = indices[self.rank : self.total_size : self.world_size]
-
- return iter(indices[self.step :])
-
- def __len__(self) -> int:
- """The number of samples in this rank."""
- return self.num_samples - self.step
-
- def set_epoch(self, epoch: int, step=0) -> None:
- """Sets the epoch for this sampler.
-
- When :attr:`shuffle=True`, this ensures all replicas use a different
- random ordering for each epoch. Otherwise, the next iteration of this
- sampler will yield the same ordering.
-
- Args:
- epoch (int): Epoch number.
- """
- self.epoch = epoch
- self.step = step
-
-
-def get_length_grouped_indices(max_lengths, group_batch_size, dp_size, seed=None):
- if seed is not None:
- torch.manual_seed(seed)
- random.seed(seed)
-
- assert all(leng != 0 for leng in max_lengths), "Should not have zero length."
- indices = torch.randperm(len(max_lengths))
- megabatches = [
- indices[i : i + group_batch_size].tolist()
- for i in range(0, len(max_lengths), group_batch_size)
- ]
- output = []
- for megabatch in megabatches:
- megabatch = sorted(megabatch, key=lambda i: max_lengths[i], reverse=True)
- grouped_megabatch = [
- megabatch[i : i + dp_size] for i in range(0, len(megabatch), dp_size)
- ]
- random.shuffle(grouped_megabatch)
- for group in grouped_megabatch:
- output.extend(group)
-
- return output
-
-
-class LengthGroupedSampler(Sampler):
- def __init__(
- self,
- dataset: Sized,
- dp_mesh: DeviceMesh,
- global_batch_size: int,
- length_attr: str = "longest",
- mega_batch_mult: Optional[int] = None,
- seed: Optional[int] = None,
- round_up: bool = True,
- ) -> None:
- rank = dp_mesh.get_local_rank()
- world_size = dp_mesh.size()
- self.rank = rank
- self.world_size = world_size
- assert global_batch_size % world_size == 0
-
- self.dataset = dataset
- if seed is None:
- seed = sync_random_seed()
- self.seed = seed
- self.epoch = 0
- self.step = 0
- self.round_up = round_up
-
- if self.round_up:
- self.num_samples = (
- math.ceil(len(self.dataset) / global_batch_size)
- * global_batch_size
- // world_size
- )
- self.total_size = self.num_samples * self.world_size
- else:
- self.num_samples = math.ceil((len(self.dataset) - rank) / world_size)
- self.total_size = len(self.dataset)
-
- if mega_batch_mult is None:
- # Default for mega_batch_mult: 50 or the number to get 4
- # megabatches, whichever is smaller.
- mega_batch_mult = min(len(self.dataset) // (global_batch_size * 4), 50)
- # Just in case, for tiny datasets
- if mega_batch_mult == 0:
- mega_batch_mult = 1
- self.group_batch_size = mega_batch_mult * global_batch_size
-
- if isinstance(self.dataset, TorchConcatDataset):
- max_lengths = []
- for sub_dataset in self.dataset.datasets:
- if hasattr(sub_dataset, length_attr):
- max_lengths.extend(getattr(sub_dataset, length_attr))
- else:
- raise ValueError
- self.max_lengths = max_lengths
- else:
- if hasattr(self.dataset, length_attr):
- self.max_lengths = getattr(self.dataset, length_attr)
- assert isinstance(self.max_lengths, (list, tuple))
-
- self.global_batch_size = global_batch_size
-
- def __iter__(self) -> Iterator[int]:
- """Iterate the indices."""
- generator = torch.Generator()
- generator.manual_seed(self.seed + self.epoch)
- seed = self.seed + self.epoch
- indices = get_length_grouped_indices(
- max_lengths=self.max_lengths,
- group_batch_size=self.group_batch_size,
- dp_size=self.world_size,
- seed=seed,
- )
- assert len(set(indices)) == len(indices)
- # add extra samples to make it evenly divisible
- if self.round_up:
- indices = (indices * int(self.total_size / len(indices) + 1))[
- : self.total_size
- ]
- # subsample
- assert len(indices) == self.total_size
- indices = indices[self.rank : self.total_size : self.world_size]
- assert len(indices) == self.num_samples
- return iter(indices[self.step :])
-
- def __len__(self) -> int:
- """The number of samples in this rank."""
- return self.num_samples - self.step
-
- def set_epoch(self, epoch: int, step=0) -> None:
- """Sets the epoch for this sampler.
-
- When :attr:`shuffle=True`, this ensures all replicas use a different
- random ordering for each epoch. Otherwise, the next iteration of this
- sampler will yield the same ordering.
-
- Args:
- epoch (int): Epoch number.
- """
- self.epoch = epoch
- self.step = step
-
-
-def vlm_get_length_grouped_indices(
- max_lengths, group_batch_size, generator=None, **kwargs
-):
- def process(lengths, group_batch_size, generator=None):
- indices = torch.randperm(len(lengths), generator=generator)
- megabatches = [
- indices[i : i + group_batch_size].tolist()
- for i in range(0, len(lengths), group_batch_size)
- ]
- megabatches = [
- sorted(megabatch, key=lambda i: lengths[i], reverse=True)
- for megabatch in megabatches
- ]
- return megabatches
-
- lengths = max_lengths
- assert all(leng != 0 for leng in lengths), "Should not have zero length."
- if all(leng > 0 for leng in lengths) or all(leng < 0 for leng in lengths):
- # all samples are in the same modality
- megabatches = process(lengths, group_batch_size, generator=generator)
- else:
- mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
- lang_indices, lang_lengths = zip(
- *[(i, -l) for i, l in enumerate(lengths) if l < 0]
- )
- mm_megabatches = []
- for mm_megabatch in process(mm_lengths, group_batch_size, generator=generator):
- mm_megabatches.append([mm_indices[i] for i in mm_megabatch])
- lang_megabatches = []
- for lang_megabatch in process(
- lang_lengths, group_batch_size, generator=generator
- ):
- lang_megabatches.append([lang_indices[i] for i in lang_megabatch])
-
- last_mm = mm_megabatches[-1]
- last_lang = lang_megabatches[-1]
- last_batch = last_mm + last_lang
- megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
-
- megabatch_indices = torch.randperm(len(megabatches), generator=generator)
- megabatches = [megabatches[i] for i in megabatch_indices]
-
- if len(last_batch) > 0:
- megabatches.append(
- sorted(last_batch, key=lambda i: abs(lengths[i]), reverse=True)
- )
-
- # The rest is to get the biggest batch first.
- # Since each megabatch is sorted by descending length,
- # the longest element is the first
- megabatch_maximums = [abs(lengths[megabatch[0]]) for megabatch in megabatches]
- max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
- # Switch to put the longest element in first position
- megabatches[0][0], megabatches[max_idx][0] = (
- megabatches[max_idx][0],
- megabatches[0][0],
- )
-
- return [i for megabatch in megabatches for i in megabatch]
-
-
-class VLMLengthGroupedSampler(Sampler):
- def __init__(
- self,
- dataset: Sized,
- dp_mesh: DeviceMesh,
- global_batch_size: int,
- mega_batch_mult: Optional[int] = None,
- seed: Optional[int] = None,
- round_up: bool = True,
- length_property="length",
- ) -> None:
- rank = dp_mesh.get_local_rank()
- world_size = dp_mesh.size()
- self.rank = rank
- self.world_size = world_size
- assert global_batch_size % world_size == 0
-
- self.dataset = dataset
- if seed is None:
- seed = sync_random_seed()
- self.seed = seed
- self.epoch = 0
- self.step = 0
- self.round_up = round_up
-
- if self.round_up:
- self.num_samples = (
- math.ceil(len(self.dataset) / global_batch_size)
- * global_batch_size
- // world_size
- )
- self.total_size = self.num_samples * self.world_size
- else:
- self.num_samples = math.ceil((len(self.dataset) - rank) / world_size)
- self.total_size = len(self.dataset)
-
- if mega_batch_mult is None:
- # Default for mega_batch_mult: 50 or the number to get 4
- # megabatches, whichever is smaller.
- mega_batch_mult = min(len(self.dataset) // (global_batch_size * 4), 50)
- # Just in case, for tiny datasets
- if mega_batch_mult == 0:
- mega_batch_mult = 1
- self.group_batch_size = mega_batch_mult * global_batch_size
-
- if isinstance(self.dataset, TorchConcatDataset):
- max_lengths = []
- for sub_dataset in self.dataset.datasets:
- max_lengths.extend(getattr(sub_dataset, length_property))
- self.max_lengths = max_lengths
- else:
- self.max_lengths = getattr(self.dataset, length_property)
- assert isinstance(self.max_lengths, (list, tuple))
-
- self.global_batch_size = global_batch_size
-
- def __iter__(self) -> Iterator[int]:
- """Iterate the indices."""
- generator = torch.Generator()
- generator.manual_seed(self.seed + self.epoch)
- indices = vlm_get_length_grouped_indices(
- max_lengths=self.max_lengths,
- group_batch_size=self.group_batch_size,
- dp_size=self.world_size,
- generator=generator,
- )
- assert len(set(indices)) == len(indices)
- # add extra samples to make it evenly divisible
- if self.round_up:
- indices = (indices * int(self.total_size / len(indices) + 1))[
- : self.total_size
- ]
- # subsample
- assert len(indices) == self.total_size
- indices = indices[self.rank : self.total_size : self.world_size]
- assert len(indices) == self.num_samples
- return iter(indices[self.step :])
-
- def __len__(self) -> int:
- """The number of samples in this rank."""
- return self.num_samples - self.step
-
- def set_epoch(self, epoch: int, step=0) -> None:
- """Sets the epoch for this sampler.
-
- When :attr:`shuffle=True`, this ensures all replicas use a different
- random ordering for each epoch. Otherwise, the next iteration of this
- sampler will yield the same ordering.
-
- Args:
- epoch (int): Epoch number.
- """
- self.epoch = epoch
- self.step = step
diff --git a/code/xtuner/_lite/parallel/sequence/__init__.py b/code/xtuner/_lite/parallel/sequence/__init__.py
deleted file mode 100644
index f3b771bd58e3c8ccd56b304b698e28dd5df35c2e..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/sequence/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.dist import init_dist
-
-from .attention import (
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn,
-)
-from .ops import (
- gather_for_sequence_parallel,
- gather_forward_split_backward,
- split_for_sequence_parallel,
- split_forward_gather_backward,
-)
-
-__all__ = [
- "pre_process_for_sequence_parallel_attn",
- "post_process_for_sequence_parallel_attn",
- "split_for_sequence_parallel",
- "init_dist",
- "gather_for_sequence_parallel",
- "split_forward_gather_backward",
- "gather_forward_split_backward",
-]
diff --git a/code/xtuner/_lite/parallel/sequence/attention.py b/code/xtuner/_lite/parallel/sequence/attention.py
deleted file mode 100644
index f53ba7c6795523321f3f7966ad2f20fe0a62f099..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/sequence/attention.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from torch.distributed.device_mesh import DeviceMesh
-
-from ..comm import all_to_all
-
-
-def pre_process_for_sequence_parallel_attn(
- query_states: torch.Tensor,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- sp_mesh: DeviceMesh,
- scatter_dim: int = 2,
- gather_dim: int = 1,
-):
- sp_size = sp_mesh.size()
- n_head = query_states.shape[2]
- assert n_head % sp_size == 0, (
- "The number of attention heads should be divisible by "
- f"sequence_parallel_world_size. But got n_head = {n_head} and "
- f"sequence_parallel_world_size = {sp_size}."
- )
-
- # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim)
- sp_group = sp_mesh.get_group()
- query_states = all_to_all(
- query_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim
- )
- key_states = all_to_all(
- key_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim
- )
- value_states = all_to_all(
- value_states, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim
- )
-
- return query_states, key_states, value_states
-
-
-def post_process_for_sequence_parallel_attn(
- attn_output: torch.Tensor, sp_mesh: DeviceMesh, scatter_dim=1, gather_dim=2
-):
- # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
- sp_group = sp_mesh.get_group()
- output = all_to_all(
- attn_output, sp_group, scatter_dim=scatter_dim, gather_dim=gather_dim
- )
- return output
diff --git a/code/xtuner/_lite/parallel/sequence/ops.py b/code/xtuner/_lite/parallel/sequence/ops.py
deleted file mode 100644
index c5421ee263bd882d81df342bf39fad8831503d36..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/sequence/ops.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.distributed as dist
-
-
-def split_for_sequence_parallel(input, dim: int, sp_mesh):
- """Splits the input tensor along a given dimension for sequence parallel.
-
- Args:
- input: The input tensor to be split.
- dim: The dimension along which the tensor should be split.
- sp_group: The sequence parallel process group.
-
- Returns:
- The split tensor corresponding to the current rank's chunk.
- """
- sp_group = sp_mesh.get_group()
- sp_size = sp_mesh.size()
- if sp_size == 1:
- return input
-
- rank = dist.get_rank(sp_group)
- dim_size = input.size(dim)
- assert dim_size % sp_size == 0, (
- f"The dimension to split ({dim_size}) is not a multiple of "
- f"sp size ({sp_size}), cannot split tensor evenly"
- )
-
- tensor_list = torch.split(input, dim_size // sp_size, dim=dim)
- output = tensor_list[rank].contiguous()
-
- return output
-
-
-def gather_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
- """Gathers the input tensor along a given dimension for sequence parallel.
-
- Args:
- input: The input tensor to be gathered.
- dim: The dimension along which the tensor should be gathered.
- sp_group: The sequence parallel process group.
-
- Returns:
- The gathered tensor concatenated along the specified dimension.
- """
- input = input.contiguous()
- world_size = dist.get_world_size(sp_group)
- dist.get_rank(sp_group)
-
- if world_size == 1:
- return input
-
- tensor_list = [torch.empty_like(input) for _ in range(world_size)]
- assert input.device.type == "cuda"
- dist.all_gather(tensor_list, input, group=sp_group)
-
- output = torch.cat(tensor_list, dim=dim).contiguous()
-
- return output
-
-
-class _GatherForwardSplitBackward(torch.autograd.Function):
- """Gather the input during forward.
-
- Scale and split the grad and keep only the corresponding chuck to the rank during backward.
- """
-
- @staticmethod
- def forward(ctx, input, dim, sp_group, grad_scale):
- ctx.dim = dim
- ctx.sp_group = sp_group
- ctx.grad_scale = grad_scale
- return gather_for_sequence_parallel(input, dim, sp_group)
-
- @staticmethod
- def backward(ctx, grad_output):
- if ctx.grad_scale == "up":
- grad_output = grad_output * dist.get_world_size(ctx.sp_group)
- elif ctx.grad_scale == "down":
- grad_output = grad_output / dist.get_world_size(ctx.sp_group)
-
- return (
- split_for_sequence_parallel(grad_output, ctx.dim, ctx.sp_group),
- None,
- None,
- None,
- )
-
-
-class _SplitForwardGatherBackward(torch.autograd.Function):
- """Split the input and keep only the corresponding chuck to the rank during
- forward.
-
- Scale and gather the grad during backward.
- """
-
- @staticmethod
- def forward(ctx, input, dim, sp_group, grad_scale):
- ctx.dim = dim
- ctx.sp_group = sp_group
- ctx.grad_scale = grad_scale
- return split_for_sequence_parallel(input, dim, sp_group)
-
- @staticmethod
- def backward(ctx, grad_output):
- if ctx.grad_scale == "up":
- grad_output = grad_output * dist.get_world_size(ctx.sp_group)
- elif ctx.grad_scale == "down":
- grad_output = grad_output / dist.get_world_size(ctx.sp_group)
- return (
- gather_for_sequence_parallel(grad_output, ctx.dim, ctx.sp_group),
- None,
- None,
- None,
- )
-
-
-def split_forward_gather_backward(input, dim, sp_group, grad_scale=None):
- """Split tensors according to the sp rank during forward propagation and
- gather the grad from the whole sp group during backward propagation.
-
- 1. When do we need this? input.requires_grad = True
-
- 2. Why we need grad scale?
-
- We have to scale down the grads as `gather_forward_split_backward` scales
- up the grads.
- """
- return _SplitForwardGatherBackward.apply(input, dim, sp_group, grad_scale)
-
-
-def gather_forward_split_backward(input, dim, sp_group, grad_scale=None):
- """Gather tensors from the whole sp group during forward propagation and
- split the grad according to the sp rank during backward propagation.
-
- 1. When do we need this?
-
- When sp is greater than 1, we need to slice the input `x` along
- sequence length dimension before it is passed into the model and get
- `sub_seq_x`. We then pass `sub_seq_x` into model and get output
- `sub_seq_out`. If the loss calculation process needs to use the complete
- output, we have to gather the `sub_seq_out` in all sp ranks during forward
- propagation and split the grad during backward propagation.
-
- 2. Why we need grad scale?
- Here is a simple case.
-
- -------- SP 1 -----------
- Suppose here is a toy model with only one linear module
- (in_features = 2, out_features = 1) and the input x has shape(2, 2).
- Y = [[y1], = [[w11x11 + w21x12], = [[x11, x12], dot [[w11],
- [y2]] [w11x21 + w21x22]] [x21, x22]] [w21]]
- z = mean(Y) = (y1 + y2) / 2
- Here is the partial derivative of z with respect to w11:
- ∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 + ∂z / ∂y2 * ∂y2 / ∂w11
- = 1/2 * x11 + 1/2 * x21 = (x11 + x21) / 2
-
- -------- SP 2 -----------
- When sequence parallel world size is set to 2, we will split the input x
- and scatter them to the two rank in the same sequence parallel group.
- ```Step 1
- Y_rank0 = [[y1]] = [[w11x11 + w21x12]] = [[x11, x12]] dot [[w11, w21]]^T
- Y_rank1 = [[y2]] = [[w11x21 + w21x22]] = [[x21, x22]] dot [[w11, w21]]^T
- ```
-
- Then, we have to gather them:
- ```Step 2
- Y_rank0 = [[y1],
- detach([y2])]
- Y_rank1 = [detach([y1]),
- [y2]]
- ```
- Note that y2 in Y_rank0 does not have grad, neither does y1 in Y_rank1.
-
- Similarly, we calculate the loss in each rank:
- ```Step 3
- z_rank0 = mean(Y_rank0) = (y1 + detach(y2)) / 2
- z_rank1 = mean(Y_rank1) = (detach(y1) + y2) / 2
- ```
- So the partial derivative of loss_rank0 with respect to w11:
- ```∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 = x11 / 2```
- The same for rank1:
- ```∂z / ∂w11 = ∂z / ∂y2 * ∂y2 / ∂w11 = x21 / 2```
-
- Finally, we need to all_reduce them:
- ```Step 4
- In both rank:
- ∂z / ∂w11 = (x11 / 2 + x21 / 2) / 2 = (x11 + x21) / 4
- ```
-
- In SP2, the gradient of each param is only half of that in SP1.
- So we should scale up the grad during the backward process in Step 2.
- """ # noqa: E501
- return _GatherForwardSplitBackward.apply(input, dim, sp_group, grad_scale)
diff --git a/code/xtuner/_lite/parallel/setup.py b/code/xtuner/_lite/parallel/setup.py
deleted file mode 100644
index 55eeb78d6f6bdaf8d7d25d851d1ea12ed200cd5c..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/parallel/setup.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.distributed as dist
-from mmengine.dist import infer_launcher, init_dist
-from torch._C._distributed_c10d import ReduceOp
-from torch.distributed.c10d_logger import _exception_logger
-
-from xtuner._lite import get_device
-
-origin_reduce_scatter_tensor = torch.distributed.reduce_scatter_tensor
-
-
-# mlu's reduce_scatter_tensor do not support ReduceOp.AVG, use ReduceOp.SUM / group_world_size instead.
-@_exception_logger
-def mlu_reduce_scatter_tensor(
- output, input, op=ReduceOp.SUM, group=None, async_op=False
-):
- if op == ReduceOp.AVG:
- result = origin_reduce_scatter_tensor(
- output, input, ReduceOp.SUM, group, async_op
- )
- output.div_(torch.distributed.get_world_size(group))
- return result
- else:
- return origin_reduce_scatter_tensor(output, input, op, group, async_op)
-
-
-def setup_parallel():
- if not dist.is_initialized():
- dist_launcher = infer_launcher()
- init_dist(dist_launcher)
-
- device = get_device()
-
- if device == "mlu":
- torch.distributed.reduce_scatter_tensor = mlu_reduce_scatter_tensor
diff --git a/code/xtuner/_lite/patches/__init__.py b/code/xtuner/_lite/patches/__init__.py
deleted file mode 100644
index cec5cad9faaee1a4074e861220bac3387f5a5356..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .auto import AutoPatch
-from .base import FSDPConfig
-from .utils import pad_to_max_length, pad_to_multiple_of
-
-__all__ = ["AutoPatch", "FSDPConfig", "pad_to_max_length", "pad_to_multiple_of"]
diff --git a/code/xtuner/_lite/patches/auto.py b/code/xtuner/_lite/patches/auto.py
deleted file mode 100644
index 55ba748ee9ae10703cb8b7c4cd2dc5f63e05b1df..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/auto.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from transformers.models.llama import LlamaForCausalLM
-from transformers.models.qwen2 import Qwen2ForCausalLM
-
-from xtuner._lite.modelings.internlm3 import InternLM3ForCausalLM
-
-from .base import FSDPConfig, PatchedCausalLM
-from .internlm3 import (
- CUDAPatchedInternLM3ForCausalLM,
- MLUPatchedInternLM3ForCausalLM,
- MuxiPatchedInternLM3ForCausalLM,
-)
-from .llama import (
- CUDAPatchedLlamaForCausalLM,
- MLUPatchedLlamaForCausalLM,
- MuxiPatchedLlamaForCausalLM,
-)
-from .qwen2 import CUDAPatchedQwen2ForCausalLM
-
-CUDA_PATCH_MAP = {
- LlamaForCausalLM: CUDAPatchedLlamaForCausalLM,
- InternLM3ForCausalLM: CUDAPatchedInternLM3ForCausalLM,
- Qwen2ForCausalLM: CUDAPatchedQwen2ForCausalLM,
-}
-
-MLU_PATCH_MAP = {
- LlamaForCausalLM: MLUPatchedLlamaForCausalLM,
- InternLM3ForCausalLM: MLUPatchedInternLM3ForCausalLM,
-}
-
-MUXI_PATCH_MAP = {
- LlamaForCausalLM: MuxiPatchedLlamaForCausalLM,
- InternLM3ForCausalLM: MuxiPatchedInternLM3ForCausalLM,
-}
-
-
-class AutoPatch:
- @classmethod
- def from_causal_lm(
- cls, model, fsdp_config: FSDPConfig, device_type="cuda"
- ) -> PatchedCausalLM:
- if device_type == "cuda":
- patch_cls = CUDA_PATCH_MAP[type(model)]
- elif device_type == "mlu":
- patch_cls = MLU_PATCH_MAP[type(model)]
- elif device_type == "muxi":
- patch_cls = MUXI_PATCH_MAP[type(model)]
- else:
- raise NotImplementedError
-
- patched_model = patch_cls(model, fsdp_config)
-
- return patched_model
diff --git a/code/xtuner/_lite/patches/base.py b/code/xtuner/_lite/patches/base.py
deleted file mode 100644
index b332b8fca337deb3cbd93ae68b8fc2f4676f3cc2..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/base.py
+++ /dev/null
@@ -1,440 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import json
-import os
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
-
-import torch
-from accelerate.utils import set_module_tensor_to_device
-from safetensors import safe_open
-from torch import Tensor
-from torch import distributed as dist
-from torch import nn
-from torch.nn.utils.clip_grad import _no_grad
-from torch.utils._foreach_utils import (
- _device_has_foreach_support,
- _group_tensors_by_device_and_dtype,
- _has_foreach_support,
-)
-from transformers import PreTrainedModel
-
-from xtuner._lite import get_logger, get_torch_device_module
-
-logger = get_logger()
-
-DEVICE_MODULE = get_torch_device_module()
-
-
-@_no_grad
-def clip_grad_norm_(
- parameters,
- fsdp_mesh,
- max_norm: float,
- norm_type: float = 2.0,
- error_if_nonfinite: bool = False,
- foreach=None,
-) -> torch.Tensor:
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
- grads = [p.grad for p in parameters if p.grad is not None]
- max_norm = float(max_norm)
- norm_type = float(norm_type)
- if len(grads) == 0:
- return torch.tensor(0.0)
- first_device = grads[0].device
-
- grouped_grads: Dict[
- Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
- ] = _group_tensors_by_device_and_dtype(
- [grads]
- ) # type: ignore[assignment]
-
- norms: List[Tensor] = []
- for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
- if (foreach is None and _has_foreach_support(device_grads, device)) or (
- foreach and _device_has_foreach_support(device)
- ):
- # for grouped_device_grads in group_tensors_by_device_mesh(device_grads).values():
- norms.extend(torch._foreach_norm(device_grads, norm_type))
- elif foreach:
- raise RuntimeError(
- f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
- )
- else:
- norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
-
- local_sharded_norm = torch.linalg.vector_norm(
- torch.stack([norm.to_local().to(first_device) for norm in norms]),
- norm_type,
- dtype=torch.float32,
- )
-
- if norm_type == 2:
- total_norm = local_sharded_norm**norm_type
- dist.all_reduce(total_norm, group=fsdp_mesh.get_group(mesh_dim=0))
- total_norm = total_norm ** (1 / norm_type)
- else:
- raise NotImplementedError
-
- if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
- raise RuntimeError(
- f"The total norm of order {norm_type} for gradients from "
- "`parameters` is non-finite, so it cannot be clipped. To disable "
- "this error and scale the gradients by the non-finite norm anyway, "
- "set `error_if_nonfinite=False`"
- )
- clip_coef = max_norm / (total_norm + 1e-6)
- # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
- # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
- # when the gradients do not reside in CPU memory.
- clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
- for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
- if (foreach is None and _has_foreach_support(device_grads, device)) or (
- foreach and _device_has_foreach_support(device)
- ):
- torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
- elif foreach:
- raise RuntimeError(
- f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
- )
- else:
- clip_coef_clamped_device = clip_coef_clamped.to(device)
- for g in device_grads:
- g.mul_(clip_coef_clamped_device.to(g.dtype))
-
- return total_norm
-
-
-def download_model_from_hub(
- model_name_or_path: str,
- from_hub: Literal["huggingface", "modelscope"] = "huggingface",
- cache_dir: Optional[str] = None,
-) -> str:
- """Automatically download model from the HUB.
-
- Note:
- If `model_name_or_path` is a local path, it will return the path
- directly without downloading it again.
-
- Args:
- model_name_or_path (str): The model name, model path or repo id.
- config (str | None): The config path. Default is None.
- from_hub (str): The model hosting hub, modelscope, or huggingface.
- Default is huggingface.
- cache_dir (str | None):
- The save path when downloading the model. If it is None, it
- will be stored in the default location of the HUB. For
- Huggingface, it's ~/.cache/huggingface/hub, for ModelScope,
- it's ~/.cache/modelscope/hub.
- Returns:
- str: The local path of the model.
- """
- if os.path.isdir(model_name_or_path):
- model_path = model_name_or_path
- elif from_hub == "huggingface":
- from huggingface_hub import snapshot_download
-
- model_path = snapshot_download(repo_id=model_name_or_path, cache_dir=cache_dir)
- elif from_hub == "modelscope":
- from modelscope import snapshot_download
-
- model_path = snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
- else:
- # TODO support openxlab
- raise NotImplementedError(
- "The model does not support downloading "
- f"from {from_hub}, it only supports "
- "`huggingface` and `modelscope`."
- )
-
- return model_path
-
-
-class HFCheckpointLoader:
- def __init__(self, model_path, cache_dir=None, from_hub="huggingface"):
- self.model_path = download_model_from_hub(model_path, from_hub, cache_dir)
-
- if "model.safetensors.index.json" in os.listdir(self.model_path):
- index_json = os.path.join(self.model_path, "model.safetensors.index.json")
- self.use_safetensors = True
- elif "model.bin.index.json" in os.listdir(self.model_path):
- index_json = os.path.join(self.model_path, "model.bin.index.json")
- self.use_safetensors = False
- else:
- raise FileNotFoundError
-
- with open(index_json) as f:
- self.weight_map = json.load(f)["weight_map"]
-
- self.current_file = None
- self.buffer = None
-
- def load(self, key):
- if key not in self.weight_map:
- logger.warning(f"{key} not in checkpoint.")
- return
-
- _file = self.weight_map[key]
-
- if self.use_safetensors:
- if self.current_file is None:
- self.buffer = safe_open(
- os.path.join(self.model_path, _file), framework="pt"
- )
- self.current_file = _file
-
- if _file != self.current_file:
- self.buffer = safe_open(
- os.path.join(self.model_path, _file), framework="pt"
- )
- self.current_file = _file
- weight = self.buffer.get_tensor(key)
-
- else:
- if self.current_file is None:
- self.buffer = torch.load(os.path.join(self.model_path, _file))
- self.current_file = _file
-
- if _file != self.current_file:
- self.buffer = torch.load(os.path.join(self.model_path, _file))
-
- weight = self.buffer[key]
-
- return weight
-
-
-@torch.no_grad
-def lazy_init_fn(module, module2name, checkpoint_loader):
- device = DEVICE_MODULE.current_device()
-
- module_name = module2name[module]
-
- params = {
- name: checkpoint_loader.load(f"{module_name}.{name}")
- for name, _ in module.named_parameters(recurse=False)
- }
-
- buffers = {
- name: checkpoint_loader.load(f"{module_name}.{name}")
- for name, _ in module.named_buffers(recurse=False)
- if f"{module_name}.{name}" in checkpoint_loader.weight_map
- }
-
- module.to_empty(device=DEVICE_MODULE.current_device(), recurse=False)
-
- for name, param in module.named_parameters(recurse=False):
- dtype = param.dtype
-
- if params[name] is None:
- continue
-
- _param = params[name].to(device).to(dtype)
-
- if param.shape == _param.shape:
- param.data.copy_(_param)
- else:
- logger.warning(
- f"The shape of {module_name}.{name}({param.shape}) "
- f"is inconsistent with that in the checkpoint({_param.shape}), "
- "it is initialized to 0 by default."
- )
- param.data.zero_()
-
- for name, buffer in module.named_buffers(recurse=False):
- if name in buffers:
- _buffer = buffers[name].to(device).to(buffer.dtype)
-
- if buffer.shape == _buffer.shape:
- buffer.data.copy_(_buffer)
- else:
- logger.warning(
- f"The shape of {module_name}.{name}({buffer.shape}) "
- f"is inconsistent with that in the checkpoint({_buffer.shape}), "
- "it is initialized to 0 by default."
- )
- buffer.data.zero_()
-
-
-@dataclass
-class FSDPConfig:
- tp_size: int = 1
- sp_size: int = 1
- ep_size: int = 1
- reshard_after_forward: bool = True
- recompute_ratio: float = 1.0
- cpu_offload: bool = True
- param_dtype: torch.dtype = torch.bfloat16
- reduce_dtype: torch.dtype = torch.bfloat16
- torch_compile: torch.dtype = False
- max_length: Optional[int] = None
- mesh_prefix: str = "default"
-
-
-@dataclass
-class ModelConfig:
- num_hidden_layers: int
- num_attention_heads: int
- num_key_value_heads: int
- head_dim: int
- hidden_size: int
- intermediate_size: int
- vocab_size: int
-
-
-class PatchedCausalLM(ABC, nn.Module):
- def __init__(self, model: PreTrainedModel, fsdp_config: FSDPConfig):
- super().__init__()
-
- @property
- @abstractmethod
- def rank0_model(self) -> Optional[PreTrainedModel]:
- pass
-
- @property
- @abstractmethod
- def patched_model(self) -> PreTrainedModel:
- pass
-
- @property
- @abstractmethod
- def fsdp_config(self) -> FSDPConfig:
- pass
-
- @property
- @abstractmethod
- def model_config(self) -> ModelConfig:
- pass
-
- @property
- @abstractmethod
- def data_parallel_mesh(self):
- pass
-
- @property
- @abstractmethod
- def data_mesh(self):
- pass
-
- @property
- @abstractmethod
- def sequence_parallel_mesh(self):
- pass
-
- @abstractmethod
- def dispatch_hf_code(self, model) -> PreTrainedModel:
- pass
-
- @abstractmethod
- def fully_shard(self, parallel_config: FSDPConfig):
- pass
-
- @abstractmethod
- def trainable_parameters(self) -> List[Dict[str, List[nn.Parameter]]]:
- pass
-
- @abstractmethod
- def clip_grad_norm(self, max_norm: float) -> torch.Tensor:
- pass
-
- def save_pretrained(
- self,
- save_directory: Union[str, os.PathLike],
- is_main_process: bool = True,
- state_dict: Optional[dict] = None,
- save_function: Callable = torch.save,
- push_to_hub: bool = False,
- max_shard_size: Union[int, str] = "5GB",
- safe_serialization: bool = True,
- variant: Optional[str] = None,
- token: Optional[Union[str, bool]] = None,
- save_peft_format: bool = True,
- **kwargs,
- ):
- if dist.is_initialized() and dist.is_available():
- rank = dist.get_rank()
- else:
- rank = 0
-
- from torch.distributed._tensor import DTensor
-
- dtype = self.patched_model.config.torch_dtype
- for name, param in self.patched_model.state_dict().items():
- if self.fsdp_config.torch_compile and "_orig_mod." in name:
- name = name.replace("_orig_mod.", "")
- if isinstance(param, DTensor):
- full_param = param.to(dtype).full_tensor().cpu()
- else:
- full_param = param.to(dtype).cpu()
-
- if rank == 0:
- set_module_tensor_to_device(self.rank0_model, name, "cpu", full_param)
-
- if rank == 0:
- self.rank0_model.save_pretrained(
- save_directory,
- is_main_process,
- state_dict,
- save_function,
- push_to_hub,
- max_shard_size,
- safe_serialization,
- variant,
- token,
- save_peft_format,
- **kwargs,
- )
-
- # def save_checkpoint(self,
- # optimizer: Optional[torch.optim.Optimizer] = None):
-
- # # FSDP cannot be saved via torch.save
- # # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html # noqa: E501
- # _options = StateDictOptions(
- # cpu_offload=True, ignore_frozen_params=True)
- # (shard_model_state_dict,
- # shard_optimizer_state_dict) = get_state_dict(
- # llm, optimizer, options=_options)
-
- # state_dict = {
- # 'model': shard_model_state_dict,
- # 'optimizer': shard_optimizer_state_dict,
- # 'train_state': train_state.state_dict(),
- # 'warmup_scheduler': warmup_scheduler.state_dict(),
- # 'cosine_scheduler': cosine_scheduler.state_dict()
- # }
-
- # mkdir_or_exist(ckpt_dir)
- # ckpt_handle = dcp.async_save(state_dict, checkpoint_id=ckpt_dir, process_group=gloo_group)
-
- # def load_checkpoint(self,
- # checkpoint_id: str,
- # optimizer: Optional[torch.optim.Optimizer] = None ):
- # _options = StateDictOptions(
- # cpu_offload=True, ignore_frozen_params=True)
- # (shard_model_state_dict,
- # shard_optimizer_state_dict) = get_state_dict(
- # patched_llm.patched_model, optimizer, options=_options)
- # state_dict = {
- # 'model': shard_model_state_dict,
- # 'optimizer': shard_optimizer_state_dict,
- # 'train_state': train_state,
- # 'warmup_scheduler': warmup_scheduler,
- # 'cosine_scheduler': cosine_scheduler
- # }
-
- # # inplace state_dict
- # dcp.load(
- # state_dict=state_dict,
- # checkpoint_id=latest_checkpoint,
- # )
-
- # _options = StateDictOptions(
- # cpu_offload=True, strict=False)
- # set_state_dict(
- # patched_llm.patched_model,
- # optimizer,
- # model_state_dict=state_dict["model"],
- # optim_state_dict=state_dict["optimizer"],
- # options=_options
- # )
diff --git a/code/xtuner/_lite/patches/internlm3.py b/code/xtuner/_lite/patches/internlm3.py
deleted file mode 100644
index 10024ece3c60311a318bad7373ce5d070d9758d7..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/internlm3.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner._lite.chat import HybridChatTemplate
-from xtuner._lite.modelings.internlm3.modeling_internlm3 import (
- InternLM3Attention,
- InternLM3DecoderLayer,
- InternLM3ForCausalLM,
- InternLM3RotaryEmbedding,
-)
-
-from .llama import CUDAPatchedLlamaForCausalLM
-
-
-class CUDAPatchedInternLM3ForCausalLM(CUDAPatchedLlamaForCausalLM):
- rotary_emb_cls = InternLM3RotaryEmbedding
- attn_cls = InternLM3Attention
- layer_cls = InternLM3DecoderLayer
- causal_cls = InternLM3ForCausalLM
-
- chat_template = HybridChatTemplate(
- system="<|im_start|>system\n{system}<|im_end|>\n",
- user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
- assistant="{assistant}<|im_end|>",
- stop_words=["<|im_end|>"],
- )
-
- def __init__(self, model, fsdp_config=None):
- super().__init__(model, fsdp_config)
-
- if fsdp_config.max_length is not None:
- self.patched_model.config.rope_scaling = {"rope_type": "default"}
- ori_max_len = self.patched_model.config.max_position_embeddings
- self.patched_model.config.max_position_embeddings = max(
- fsdp_config.max_length, ori_max_len
- )
- self.patched_model.model.rotary_emb = InternLM3RotaryEmbedding(
- self.patched_model.config
- ).to(self.device_type)
-
-
-class MLUPatchedInternLM3ForCausalLM(CUDAPatchedInternLM3ForCausalLM):
- device_type = "mlu"
-
-
-class MuxiPatchedInternLM3ForCausalLM(CUDAPatchedInternLM3ForCausalLM):
- device_type = "muxi"
diff --git a/code/xtuner/_lite/patches/llama.py b/code/xtuner/_lite/patches/llama.py
deleted file mode 100644
index aed4161dbb14380d3965c16d2b7ad0c579bf2b4f..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/llama.py
+++ /dev/null
@@ -1,1256 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import types
-from functools import partial
-from typing import Callable, List, Optional, Tuple, TypedDict, Union
-
-import torch
-from flash_attn import flash_attn_with_kvcache
-from packaging import version
-from torch import distributed as dist
-from torch import nn
-from torch.distributed._composable.fsdp import (
- CPUOffloadPolicy,
- MixedPrecisionPolicy,
- fully_shard,
-)
-from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
- checkpoint_wrapper,
-)
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.tensor.parallel import (
- ColwiseParallel,
- PrepareModuleInput,
- RowwiseParallel,
- SequenceParallel,
- parallelize_module,
-)
-from torch.nn import functional as F
-from tqdm import tqdm
-from transformers.cache_utils import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from transformers.models.llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaRMSNorm,
- LlamaRotaryEmbedding,
- apply_rotary_pos_emb,
- eager_attention_forward,
- repeat_kv,
-)
-from transformers.processing_utils import Unpack
-from transformers.utils import logging
-
-from xtuner._lite.accelerate import liger_kernel_is_available
-from xtuner._lite.chat import HybridChatTemplate
-from xtuner._lite.parallel.sequence import split_for_sequence_parallel
-from xtuner._lite.patches.base import (
- FSDPConfig,
- HFCheckpointLoader,
- ModelConfig,
- PatchedCausalLM,
- clip_grad_norm_,
- lazy_init_fn,
-)
-from xtuner._lite.patches.mixins import GenerateMixin
-from xtuner._lite.patches.utils import pad_to_max_length, pad_to_multiple_of
-
-logger = logging.get_logger(__name__)
-
-
-def all_to_all(
- input: torch.Tensor,
- scatter_dim: int,
- gather_dim: int,
- mesh: DeviceMesh,
-) -> torch.Tensor:
- group = mesh.get_group()
- world_size = mesh.size()
- input_list = [
- t.contiguous() for t in torch.tensor_split(input, world_size, scatter_dim)
- ]
- output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
- dist.nn.all_to_all(output_list, input_list, group=group)
- return torch.cat(output_list, dim=gather_dim).contiguous()
-
-
-class FlashAttentionKwargs(TypedDict, total=False):
- """Keyword arguments for Flash Attention with Compile.
-
- Attributes:
- cu_seq_lens_q (`torch.LongTensor`, *optional*)
- Gets cumulative sequence length for query state.
- cu_seq_lens_k (`torch.LongTensor`, *optional*)
- Gets cumulative sequence length for key state.
- max_length_q (`int`, *optional*):
- Maximum sequence length for query state.
- max_length_k (`int`, *optional*):
- Maximum sequence length for key state.
- """
-
- cu_seq_lens_q: Optional[torch.LongTensor]
- cu_seq_lens_k: Optional[torch.LongTensor]
- max_length_q: Optional[int]
- max_length_k: Optional[int]
- block_table: Optional[torch.Tensor]
- prefilling: Optional[bool]
-
-
-@torch.library.custom_op("xtuner::fill_paged_kv_cache", mutates_args=())
-def fill_paged_kv_cache(
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- key_cache: torch.Tensor,
- value_cache: torch.Tensor,
- cu_seq_lens_q: torch.Tensor,
- cu_seq_lens_k: torch.Tensor,
- max_length_q: int,
- max_length_k: int,
- block_table: torch.Tensor,
-) -> None:
- bs = block_table.size(0)
- from lmdeploy.pytorch.kernels import fill_kv_cache
-
- fill_kv_cache(
- key_states.transpose(1, 2)[:, : cu_seq_lens_k[bs]],
- value_states.transpose(1, 2)[:, : cu_seq_lens_k[bs]],
- key_cache,
- value_cache,
- cu_seq_lens_q[:bs], # q_start_loc
- cu_seq_lens_q[1 : bs + 1] - cu_seq_lens_q[:bs], # q_seq_length
- kv_seq_length=cu_seq_lens_k[1 : bs + 1] - cu_seq_lens_k[:bs],
- max_q_seq_length=max_length_q,
- block_offsets=block_table,
- )
-
-
-@fill_paged_kv_cache.register_fake
-def fill_paged_kv_cache_fake(
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- key_cache: torch.Tensor,
- value_cache: torch.Tensor,
- cu_seq_lens_q: torch.Tensor,
- cu_seq_lens_k: torch.Tensor,
- max_length_q: int,
- max_length_k: int,
- block_table: torch.Tensor,
-) -> None:
- return None
-
-
-@torch.library.custom_op("xtuner::paged_attention_decoding", mutates_args=())
-def paged_attention_decoding(
- query_states: torch.Tensor,
- key_cache: torch.Tensor,
- value_cache: torch.Tensor,
- cache_seqlens: torch.Tensor,
- block_table: torch.Tensor,
-) -> torch.Tensor:
- bs = block_table.size(0)
- attn_outputs = flash_attn_with_kvcache(
- query_states.transpose(1, 2).transpose(0, 1)[:bs],
- key_cache,
- value_cache,
- cache_seqlens=cache_seqlens,
- block_table=block_table,
- causal=True,
- )
- return attn_outputs
-
-
-@paged_attention_decoding.register_fake
-def paged_attention_decoding_fake(
- query_states: torch.Tensor,
- key_cache: torch.Tensor,
- value_cache: torch.Tensor,
- cache_seqlens: torch.Tensor,
- block_table: torch.Tensor,
-):
- bs = block_table.size(0)
- return torch.empty_like(query_states.transpose(1, 2).transpose(0, 1)[:bs])
-
-
-class CUDAPatchedLlamaForCausalLM(PatchedCausalLM, GenerateMixin):
- device_type = "cuda"
- rotary_emb_cls = LlamaRotaryEmbedding
- attn_cls = LlamaAttention
- norm_cls = LlamaRMSNorm
- layer_cls = LlamaDecoderLayer
- causal_cls = LlamaForCausalLM
-
- layer_tp_plan = {
- "input_layernorm": SequenceParallel(),
- "self_attn": PrepareModuleInput(
- input_layouts=(Shard(1),),
- desired_input_layouts=(Replicate(),),
- ),
- "self_attn.q_proj": ColwiseParallel(),
- "self_attn.k_proj": ColwiseParallel(),
- "self_attn.v_proj": ColwiseParallel(),
- "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
- "post_attention_layernorm": SequenceParallel(),
- "mlp": PrepareModuleInput(
- input_layouts=(Shard(1),),
- desired_input_layouts=(Replicate(),),
- ),
- "mlp.up_proj": ColwiseParallel(),
- "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
- "mlp.gate_proj": ColwiseParallel(),
- }
-
- casual_tp_plan = {
- "model.embed_tokens": RowwiseParallel(
- input_layouts=Replicate(),
- output_layouts=Shard(1),
- ),
- "model.norm": PrepareModuleInput(
- input_layouts=(Replicate(),),
- desired_input_layouts=(Replicate(),),
- ),
- "lm_head": PrepareModuleInput(
- input_layouts=(Replicate(),),
- desired_input_layouts=(Replicate(),),
- ),
- }
-
- chat_template = HybridChatTemplate(
- system=("<|start_header_id|>system<|end_header_id|>\n\n{system}" "<|eot_id|>"),
- user=(
- "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
- "<|start_header_id|>assistant<|end_header_id|>\n\n"
- ),
- assistant="{assistant}<|eot_id|>",
- sep="",
- stop_words=["<|eot_id|>"],
- )
-
- def __init__(
- self, model: LlamaForCausalLM, fsdp_config: Optional[FSDPConfig] = None
- ):
- super().__init__(model, fsdp_config)
-
- if dist.is_initialized() and dist.is_available():
- rank = dist.get_rank()
- else:
- rank = 0
-
- if rank == 0:
- self._rank0_model = copy.deepcopy(model)
- else:
- self._rank0_model = None
-
- self._patched_model = self.dispatch_hf_code(model)
-
- self.init_model_config(fsdp_config)
-
- self._fsdp_config = fsdp_config
- if self._fsdp_config is not None:
- self.fully_shard(fsdp_config)
-
- @property
- def patched_model(self) -> LlamaForCausalLM:
- return self._patched_model
-
- @property
- def rank0_model(self) -> LlamaForCausalLM:
- return self._rank0_model
-
- @property
- def model_config(self) -> ModelConfig:
- return self._model_config
-
- @property
- def fsdp_config(self) -> Optional[FSDPConfig]:
- return self._fsdp_config
-
- @property
- def data_parallel_mesh(self):
- return self.dp_mesh
-
- @property
- def data_mesh(self):
- return self._data_mesh
-
- @property
- def sequence_parallel_mesh(self):
- return self.sp_mesh
-
- def init_model_config(self, fsdp_config: FSDPConfig):
- assert self.patched_model.config.num_key_value_heads >= fsdp_config.tp_size
- assert self.patched_model.config.num_key_value_heads % fsdp_config.tp_size == 0
-
- self._model_config = ModelConfig(
- num_hidden_layers=self.patched_model.config.num_hidden_layers,
- num_attention_heads=self.patched_model.config.num_attention_heads,
- num_key_value_heads=self.patched_model.config.num_key_value_heads
- // fsdp_config.tp_size,
- hidden_size=self.patched_model.config.hidden_size,
- intermediate_size=self.patched_model.config.intermediate_size,
- vocab_size=self.patched_model.config.vocab_size,
- head_dim=self.patched_model.config.head_dim,
- )
-
- @classmethod
- def dispatch_hf_code(cls, model) -> LlamaForCausalLM:
- for name, module in model.named_modules():
- if isinstance(module, cls.attn_cls):
- module.forward = types.MethodType(cls.patched_attn_forward, module)
- if isinstance(module, cls.causal_cls):
- module.forward = types.MethodType(cls.patched_casual_forward, module)
- if isinstance(module, cls.layer_cls):
- module.forward = types.MethodType(cls.patched_layer_forward, module)
-
- return model
-
- def fully_shard(self, fsdp_config: FSDPConfig) -> None:
- if fsdp_config.ep_size > 1:
- raise NotImplementedError
-
- world_size = dist.get_world_size()
- sp_size = fsdp_config.sp_size
- tp_size = fsdp_config.tp_size
-
- if tp_size > sp_size:
- # add warning
- pass
- elif tp_size < sp_size:
- assert sp_size % tp_size == 0
- sp_size = sp_size // tp_size
-
- assert world_size % sp_size == 0
- assert world_size % tp_size == 0
- world_mesh_name = f"{fsdp_config.mesh_prefix}.world"
- fsdp_mesh_name = f"{fsdp_config.mesh_prefix}.fsdp"
- tp_mesh_name = f"{fsdp_config.mesh_prefix}.tp"
- dp_mesh_name = f"{fsdp_config.mesh_prefix}.dp"
- sp_mesh_name = f"{fsdp_config.mesh_prefix}.sp"
- data_mesh_name = f"{fsdp_config.mesh_prefix}.data"
- _tp_mesh_name = f"{fsdp_config.mesh_prefix}._tp"
-
- world_mesh = init_device_mesh(
- self.device_type,
- (world_size,),
- mesh_dim_names=[
- world_mesh_name,
- ],
- )
- self.world_mesh = world_mesh[world_mesh_name]
-
- model_mesh = init_device_mesh(
- self.device_type,
- (world_size // tp_size, tp_size),
- mesh_dim_names=[fsdp_mesh_name, tp_mesh_name],
- )
-
- fsdp_mesh = model_mesh[fsdp_mesh_name]
- tp_mesh = model_mesh[tp_mesh_name]
-
- self.tp_mesh = tp_mesh
- self.fsdp_mesh = fsdp_mesh
-
- data_mesh = init_device_mesh(
- self.device_type,
- (world_size // tp_size // sp_size, sp_size, tp_size),
- mesh_dim_names=[dp_mesh_name, sp_mesh_name, _tp_mesh_name],
- )
- self.dp_mesh = data_mesh[dp_mesh_name]
- self.sp_mesh = data_mesh[sp_mesh_name]
-
- _data_mesh = init_device_mesh(
- self.device_type,
- (world_size // tp_size // sp_size, sp_size * tp_size),
- mesh_dim_names=[dp_mesh_name, data_mesh_name],
- )
- self._data_mesh = _data_mesh[data_mesh_name]
-
- param_init_fn = partial(
- lazy_init_fn,
- module2name={mod: name for name, mod in self.patched_model.named_modules()},
- checkpoint_loader=HFCheckpointLoader(
- self.patched_model.config._name_or_path
- ),
- )
-
- mp_policy = MixedPrecisionPolicy(
- param_dtype=fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
- )
-
- self.patched_model.model.rotary_emb = self.rotary_emb_cls(
- self.patched_model.config
- )
-
- num_recompute_layers = int(
- self.model_config.num_hidden_layers * fsdp_config.recompute_ratio
- )
-
- from torch.distributed._symmetric_memory import enable_symm_mem_for_group
-
- torch._inductor.config._micro_pipeline_tp = True
- enable_symm_mem_for_group(self.tp_mesh.get_group().group_name)
-
- if fsdp_config.torch_compile:
- compiled_layers = []
-
- for layer in tqdm(self.patched_model.model.layers):
- layer.apply(param_init_fn)
-
- attention = layer.self_attn
-
- if tp_mesh.size() > 1:
- parallelize_module(
- module=layer,
- device_mesh=tp_mesh,
- parallelize_plan=self.layer_tp_plan,
- )
-
- if attention.layer_idx < num_recompute_layers:
- layer = checkpoint_wrapper(layer, preserve_rng_state=False)
-
- if fsdp_config.torch_compile:
- layer = torch.compile(layer, fullgraph=True)
-
- self.patched_model.model.layers.register_module(
- str(attention.layer_idx), layer
- )
-
- fully_shard(
- layer,
- mesh=fsdp_mesh,
- mp_policy=mp_policy,
- reshard_after_forward=fsdp_config.reshard_after_forward,
- offload_policy=CPUOffloadPolicy() if fsdp_config.cpu_offload else None,
- )
-
- if fsdp_config.torch_compile:
- compiled_layers.append(layer)
-
- if version.parse(torch.__version__) >= version.parse("2.5.0"):
- for layer_cur, layer_next in zip(
- self.patched_model.model.layers[:-1],
- self.patched_model.model.layers[1:],
- ):
- layer_cur.set_modules_to_forward_prefetch([layer_next])
-
- self.patched_model.lm_head.apply(param_init_fn)
- self.patched_model.model.embed_tokens.apply(param_init_fn)
- self.patched_model.model.norm.apply(param_init_fn)
-
- if tp_mesh.size() > 1:
- _weight = self.patched_model.lm_head.weight
- _dtensor_weight = nn.Parameter(
- distribute_tensor(_weight, tp_mesh, [Replicate()])
- )
- self.patched_model.lm_head.register_parameter("weight", _dtensor_weight)
-
- _weight = self.patched_model.model.norm.weight
- _dtensor_weight = nn.Parameter(
- distribute_tensor(_weight, tp_mesh, [Replicate()])
- )
- self.patched_model.model.norm.register_parameter("weight", _dtensor_weight)
-
- parallelize_module(
- self.patched_model,
- tp_mesh,
- self.casual_tp_plan,
- )
-
- fully_shard(
- self.patched_model,
- mesh=fsdp_mesh,
- mp_policy=mp_policy,
- reshard_after_forward=fsdp_config.reshard_after_forward,
- offload_policy=CPUOffloadPolicy() if fsdp_config.cpu_offload else None,
- )
-
- @staticmethod
- def patched_attn_forward(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- position_ids: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = False,
- past_key_value: Optional[Cache] = None,
- use_cache: Optional[bool] = True,
- cache_position: Optional[torch.LongTensor] = None,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if "block_table" in kwargs and kwargs["block_table"] is not None:
- # generating
- if "prefilling" in kwargs and kwargs["prefilling"]:
- return CUDAPatchedLlamaForCausalLM.patched_attn_prefilling(
- self,
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- position_ids=position_ids,
- cache_position=cache_position,
- output_attentions=output_attentions,
- sequence_parallel_mesh=sequence_parallel_mesh,
- **kwargs,
- )
- else:
- return CUDAPatchedLlamaForCausalLM.patched_attn_decoding(
- self,
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- position_ids=position_ids,
- cache_position=cache_position,
- output_attentions=output_attentions,
- sequence_parallel_mesh=sequence_parallel_mesh,
- **kwargs,
- )
- else:
- return CUDAPatchedLlamaForCausalLM.patched_attn_forward_training(
- self,
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- position_ids=position_ids,
- cache_position=cache_position,
- output_attentions=output_attentions,
- sequence_parallel_mesh=sequence_parallel_mesh,
- **kwargs,
- )
-
- @staticmethod
- def patched_attn_forward_training(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- position_ids: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin
- )
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs
- )
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and output_attentions:
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
- "Falling back to eager attention. This warning can be removed using the argument "
- '`attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[
- self.config._attn_implementation
- ]
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- sp_size = sequence_parallel_mesh.size()
- num_kv_heads = key_states.size(1)
- if sp_size > num_kv_heads:
- assert sp_size % num_kv_heads == 0
- key_states = repeat_kv(key_states, sp_size // num_kv_heads)
- value_states = repeat_kv(value_states, sp_size // num_kv_heads)
-
- query_states = all_to_all(
- query_states, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
- key_states = all_to_all(
- key_states, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
- value_states = all_to_all(
- value_states, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
-
- # (bs, n , qh // sp, d)
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- position_ids=position_ids,
- **kwargs,
- )
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- attn_output = all_to_all(
- attn_output, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
- @staticmethod
- def patched_attn_prefilling(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- position_ids: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = False,
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin
- )
-
- fill_paged_kv_cache(
- key_states,
- value_states,
- past_key_value[self.layer_idx][0],
- past_key_value[self.layer_idx][1],
- kwargs["cu_seq_lens_q"],
- kwargs["cu_seq_lens_k"],
- kwargs["max_length_q"],
- kwargs["max_length_k"],
- kwargs["block_table"],
- )
-
- assert self.config._attn_implementation == "flash_attention_2"
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- position_ids=position_ids,
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
- @staticmethod
- def patched_attn_decoding(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- output_attentions: Optional[bool] = False,
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin
- )
-
- seq_lens_k = kwargs["cu_seq_lens_k"][1:] - kwargs["cu_seq_lens_k"][:-1]
- block_table = kwargs["block_table"]
- block_size = past_key_value[self.layer_idx][0].size(1)
- bs = block_table.size(0)
- assert kwargs["cu_seq_lens_k"].numel() - 1 == bs
-
- _key_states = key_states.transpose(1, 2).squeeze(0)
- _value_states = value_states.transpose(1, 2).squeeze(0)
-
- block_index = block_table[:, 0] + (seq_lens_k[:bs] - 1) // block_size
- past_key_value[self.layer_idx][0][
- block_index, (seq_lens_k[:bs] - 1) % block_size
- ] = _key_states
- past_key_value[self.layer_idx][1][
- block_index, (seq_lens_k[:bs] - 1) % block_size
- ] = _value_states
-
- assert self.config._attn_implementation == "flash_attention_2"
-
- attn_weights = None
-
- attn_output = paged_attention_decoding(
- query_states,
- past_key_value[self.layer_idx][0],
- past_key_value[self.layer_idx][1],
- kwargs["cu_seq_lens_k"][1:] - kwargs["cu_seq_lens_k"][:-1],
- block_table,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1)
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
- @staticmethod
- def patched_layer_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[
- Tuple[torch.Tensor, torch.Tensor]
- ] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
- ]:
- if "block_table" in kwargs and kwargs["block_table"] is not None:
- if "prefilling" in kwargs and kwargs["prefilling"]:
- return CUDAPatchedLlamaForCausalLM.patched_layer_forward_training(
- self,
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- else:
- return CUDAPatchedLlamaForCausalLM.patched_layer_forward_decoding(
- self,
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- else:
- return CUDAPatchedLlamaForCausalLM.patched_layer_forward_training(
- self,
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
-
- @staticmethod
- # @torch.compile(fullgraph=True)
- def patched_layer_forward_training(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[
- Tuple[torch.Tensor, torch.Tensor]
- ] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
- ]:
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
- @staticmethod
- @torch.compile(fullgraph=True)
- def patched_layer_forward_decoding(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[
- Tuple[torch.Tensor, torch.Tensor]
- ] = None, # necessary, but kept here for BC
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
- ]:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights = self.self_attn(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
- @staticmethod
- def patched_casual_forward(
- self: LlamaForCausalLM,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- label_shifted=False,
- **kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- output_attentions = (
- output_attentions
- if output_attentions is not None
- else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- **kwargs,
- )
-
- hidden_states = outputs[0]
-
- if labels is None:
- loss = None
-
- logits = self.lm_head(hidden_states)
- if isinstance(logits, DTensor):
- logits = logits.to_local()
- else:
- if liger_kernel_is_available():
- # unable to return logits when using Liger Kernel
- logits = None
-
- if label_shifted:
- shift_hidden_states = hidden_states
- shift_labels = labels
- else:
- shift_hidden_states = hidden_states[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- shift_hidden_states = shift_hidden_states.view(
- -1, self.config.hidden_size
- )
- shift_labels = shift_labels.view(-1)
- shift_labels = shift_labels.to(shift_hidden_states.device)
-
- from liger_kernel.transformers.fused_linear_cross_entropy import (
- LigerFusedLinearCrossEntropyLoss,
- )
-
- loss_fct = LigerFusedLinearCrossEntropyLoss()
-
- lm_head_weight = self.lm_head.weight
- if isinstance(lm_head_weight, DTensor):
- assert isinstance(shift_hidden_states, DTensor)
- shift_hidden_states = shift_hidden_states.to_local()
- lm_head_weight = self.lm_head.weight.to_local()
-
- loss = loss_fct(
- lm_head_weight, shift_hidden_states, shift_labels, self.lm_head.bias
- )
-
- else:
- logits = self.lm_head(hidden_states)
-
- if label_shifted:
- shift_logits = logits
- shift_labels = labels
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- shift_labels = shift_labels.to(shift_logits.device)
-
- loss_fct = torch.nn.CrossEntropyLoss()
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- label_shifted: bool = False,
- gather_logprobs: bool = False,
- cu_seq_lens_q: Optional[torch.LongTensor] = None,
- cu_seq_lens_k: Optional[torch.LongTensor] = None,
- max_length_q: Optional[int] = None,
- max_length_k: Optional[int] = None,
- block_table: Optional[torch.LongTensor] = None,
- prefilling: bool = False,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- if gather_logprobs:
- assert labels is not None and label_shifted
-
- _input_ids = input_ids
- _labels = labels
- _position_ids = position_ids
- _cu_seq_lens_q = cu_seq_lens_q
- _cu_seq_lens_k = cu_seq_lens_k
- _max_length_q = max_length_q
- _max_length_k = max_length_k
-
- if self.fsdp_config.torch_compile:
- _input_ids = pad_to_max_length(
- _input_ids, 0, self.fsdp_config.max_length, 1
- )
- _position_ids = pad_to_max_length(
- _position_ids, 0, self.fsdp_config.max_length, 1
- )
- if labels is not None:
- _labels = pad_to_max_length(
- _labels, -100, self.fsdp_config.max_length, 1
- )
- else:
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- multiple_of = sequence_parallel_mesh.size() * self.tp_mesh.size()
- else:
- multiple_of = self.tp_mesh.size()
-
- _input_ids = pad_to_multiple_of(_input_ids, 0, multiple_of, 1)
- _position_ids = pad_to_multiple_of(_position_ids, 0, multiple_of, 1)
- if labels is not None:
- _labels = pad_to_multiple_of(_labels, -100, multiple_of, 1)
-
- num_padded_tokens = _input_ids.numel() - input_ids.numel()
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- _input_ids = split_for_sequence_parallel(
- _input_ids, dim=1, sp_mesh=sequence_parallel_mesh
- )
- _position_ids = split_for_sequence_parallel(
- _position_ids, dim=1, sp_mesh=sequence_parallel_mesh
- )
-
- if labels is not None:
- _labels = split_for_sequence_parallel(
- _labels, dim=1, sp_mesh=sequence_parallel_mesh
- )
-
- if self.tp_mesh.size() > 1:
- if labels is not None:
- _labels = split_for_sequence_parallel(
- _labels, dim=1, sp_mesh=self.tp_mesh
- )
-
- if self.training and num_padded_tokens > 0:
- assert torch.any(cu_seq_lens_k == cu_seq_lens_q)
- _cu_seq_lens_q = _cu_seq_lens_q.tolist()
- _cu_seq_lens_q.append(_cu_seq_lens_q[-1] + num_padded_tokens)
-
- _cu_seq_lens_q = torch.IntTensor(_cu_seq_lens_q).to(cu_seq_lens_q.device)
- _cu_seq_lens_k = _cu_seq_lens_q
-
- _max_length_q = max(_max_length_q, num_padded_tokens)
- _max_length_k = _max_length_q
-
- outputs = self.patched_model(
- _input_ids,
- attention_mask,
- _position_ids,
- past_key_values,
- inputs_embeds,
- _labels,
- use_cache,
- output_attentions,
- output_hidden_states,
- return_dict,
- cache_position,
- num_logits_to_keep,
- label_shifted=label_shifted,
- cu_seq_lens_q=_cu_seq_lens_q,
- cu_seq_lens_k=_cu_seq_lens_k,
- max_length_q=_max_length_q,
- max_length_k=_max_length_k,
- block_table=block_table,
- prefilling=prefilling,
- sequence_parallel_mesh=self.sequence_parallel_mesh,
- )
-
- if outputs.loss is not None:
- outputs.loss = outputs.loss * (_labels >= 0).sum()
- if self.tp_mesh.size() > 1:
- outputs.loss = dist.nn.all_reduce(
- outputs.loss, group=self.tp_mesh.get_group()
- )
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- outputs.loss = dist.nn.all_reduce(
- outputs.loss, group=sequence_parallel_mesh.get_group()
- )
- outputs.loss = outputs.loss / (labels >= 0).sum()
-
- return outputs
-
- @torch.no_grad()
- def sample(
- self,
- logits,
- cu_seq_lens,
- do_sample=True,
- top_k=0,
- top_p=0.9,
- temperature=1.0,
- vocab_size=None,
- ):
- last_token_inds = cu_seq_lens[1:] - 1
- rank_start = logits.size(0) * self.tp_mesh.get_local_rank()
- rank_end = logits.size(0) * (self.tp_mesh.get_local_rank() + 1)
-
- other_rank_mask = torch.logical_or(
- last_token_inds < rank_start, last_token_inds >= rank_end
- )
- last_token_inds -= rank_start
- last_token_inds = last_token_inds.clip(min=0, max=logits.size(0) - 1)
-
- logits = logits[last_token_inds]
-
- if vocab_size is not None:
- logits[:, vocab_size:] = -torch.inf
-
- if not do_sample:
- sampled = logits.argmax(-1)
- sampled[other_rank_mask] = 0
- if self.tp_mesh.size() > 1:
- dist.all_reduce(sampled, group=self.tp_mesh.get_group())
- return sampled
-
- # Apply temperature if necessary
- if temperature != 1.0:
- logits = logits / temperature
-
- # Apply top-k if necessary
- if top_k > 0:
- top_k = min(top_k, logits.size(-1))
- _, topk_indices = logits.topk(top_k, dim=-1)
- mask = torch.ones_like(logits, dtype=torch.bool)
- mask.scatter_(-1, topk_indices, False)
- logits.masked_fill_(mask, -torch.inf)
-
- # Apply top-p (nucleus sampling) if necessary
- if top_p < 1.0:
- sorted_logits, sorted_indices = torch.sort(logits, dim=-1)
- cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-
- mask = cum_probs <= (1 - top_p)
- mask[:, -1] = False
- sorted_logits.masked_fill_(mask, -torch.inf)
-
- logits.scatter_(-1, sorted_indices, sorted_logits)
-
- probs = logits.softmax(-1)
- sampled = torch.multinomial(probs, 1).squeeze(-1)
- sampled[other_rank_mask] = 0
- if self.tp_mesh.size() > 1:
- dist.all_reduce(sampled, group=self.tp_mesh.get_group())
-
- return sampled
-
- def gather_logprobs(self, shifted_logits, shifted_labels, sequence_parallel_mesh):
- if self.fsdp_config.torch_compile:
- _labels = pad_to_max_length(
- shifted_labels, -100, self.fsdp_config.max_length, 1
- )
- else:
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- multiple_of = sequence_parallel_mesh.size() * self.tp_mesh.size()
- else:
- multiple_of = self.tp_mesh.size()
-
- _labels = pad_to_multiple_of(shifted_labels, -100, multiple_of, 1)
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- _labels = split_for_sequence_parallel(
- _labels, dim=1, sp_mesh=sequence_parallel_mesh
- )
-
- if self.tp_mesh.size() > 1:
- _labels = split_for_sequence_parallel(_labels, dim=1, sp_mesh=self.tp_mesh)
-
- logprobs = F.log_softmax(shifted_logits, dim=-1)
- logprobs = logprobs.gather(
- dim=-1, index=_labels.clip(min=0).unsqueeze(-1)
- ).squeeze(-1)
-
- if self.tp_mesh.size() > 1:
- _logprobs = dist.nn.all_gather(logprobs, group=self.tp_mesh.get_group())
- logprobs = torch.cat(_logprobs, dim=1)
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- _logprobs = dist.nn.all_gather(
- logprobs, group=sequence_parallel_mesh.get_group()
- )
- logprobs = torch.cat(_logprobs, dim=1)
-
- logprobs = logprobs[:, : shifted_labels.size(1)]
-
- return logprobs
-
- def trainable_parameters(self):
- _requried_grad_params = [
- param for param in self.patched_model.parameters() if param.requires_grad
- ]
- return _requried_grad_params
-
- def clip_grad_norm(self, max_norm):
- if self.tp_mesh.size() > 1:
- dist.all_reduce(
- self.patched_model.lm_head.weight.grad.to_local(),
- group=self.tp_mesh.get_group(),
- )
- dist.all_reduce(
- self.patched_model.model.norm.weight.grad.to_local(),
- group=self.tp_mesh.get_group(),
- )
- self.patched_model.lm_head.weight.grad.div_(self.tp_mesh.size())
- self.patched_model.model.norm.weight.grad.div_(self.tp_mesh.size())
-
- for param in self.trainable_parameters():
- param.grad.div_(self.tp_mesh.size())
-
- grad_norm = clip_grad_norm_(
- self.trainable_parameters(), self.world_mesh, max_norm
- )
- return grad_norm
-
-
-class MLUPatchedLlamaForCausalLM(CUDAPatchedLlamaForCausalLM):
- device_type = "mlu"
-
-
-class MuxiPatchedLlamaForCausalLM(CUDAPatchedLlamaForCausalLM):
- device_type = "muxi"
diff --git a/code/xtuner/_lite/patches/mixins/__init__.py b/code/xtuner/_lite/patches/mixins/__init__.py
deleted file mode 100644
index a3c940ad2cbb5eec2059bc2e3f12dad9d6d5fb96..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/mixins/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .generate import GenerateMixin
-
-__all__ = ["GenerateMixin"]
diff --git a/code/xtuner/_lite/patches/mixins/generate.py b/code/xtuner/_lite/patches/mixins/generate.py
deleted file mode 100644
index a3bd8389f6ff0dd75c5582a8469624be0811d269..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/mixins/generate.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-
-from xtuner._lite import get_logger
-from xtuner._lite.patches.utils import pack_sequence, packed_cumulative_length
-
-logger = get_logger()
-
-
-class GenerateMixin:
- @torch.no_grad()
- def build_kv_cache(
- self,
- max_batch_size,
- max_length,
- block_size=256,
- dtype=torch.bfloat16,
- device="cuda",
- ):
- num_blocks = max(max_batch_size, max_length // block_size * max_batch_size)
- head_dim = self.model_config.head_dim
- num_heads = self.model_config.num_key_value_heads
- past_key_values = []
- for _ in range(self.model_config.num_hidden_layers):
- cache_k = torch.zeros(
- num_blocks, block_size, num_heads, head_dim, dtype=dtype, device=device
- )
- cache_v = torch.zeros(
- num_blocks, block_size, num_heads, head_dim, dtype=dtype, device=device
- )
-
- past_key_values.append((cache_k, cache_v))
-
- block_table = torch.arange(num_blocks).reshape(max_batch_size, -1)
- return past_key_values, block_table
-
- @torch.no_grad()
- def prefilling(
- self,
- input_ids,
- position_ids,
- past_key_values,
- cu_seq_lens_q,
- cu_seq_lens_k,
- max_length_q,
- max_length_k,
- ):
- outputs = self(
- input_ids=input_ids,
- position_ids=position_ids,
- past_key_values=past_key_values,
- cache_position=position_ids,
- cu_seq_lens_q=cu_seq_lens_q,
- cu_seq_lens_k=cu_seq_lens_k,
- max_length_q=max_length_q,
- max_length_k=max_length_k,
- )
- return outputs.logits
-
- @torch.no_grad()
- def init_cuda_graph(
- self,
- input_ids,
- position_ids,
- past_key_values,
- cu_seq_lens_q,
- cu_seq_lens_k,
- max_length_q,
- max_length_k,
- block_table,
- ):
- s = torch.cuda.Stream()
- s.wait_stream(torch.cuda.current_stream())
-
- self.graph_block_table = block_table
- self.graph_cu_seq_lens_q = cu_seq_lens_q
- self.graph_cu_seq_lens_k = cu_seq_lens_k
- self.graph_max_length_q = max_length_q
- self.graph_max_length_k = max_length_k
- self.graph_input_ids = input_ids
- self.graph_position_ids = position_ids
- self.graph_cache_position = position_ids.clone()
-
- # 在新 stream 中预热
- with torch.cuda.stream(s):
- with torch.no_grad():
- self.graph_logits = self(
- input_ids=self.graph_input_ids,
- position_ids=self.graph_position_ids,
- past_key_values=past_key_values,
- cache_position=self.graph_cache_position,
- cu_seq_lens_q=self.graph_cu_seq_lens_q,
- cu_seq_lens_k=self.graph_cu_seq_lens_k,
- max_length_q=self.graph_max_length_q,
- max_length_k=self.graph_max_length_k,
- block_table=self.graph_block_table,
- prefilling=False,
- ).logits
-
- # 等待预热完成
- torch.cuda.current_stream().wait_stream(s)
- self.cuda_graph = torch.cuda.CUDAGraph()
- with torch.cuda.graph(self.cuda_graph):
- with torch.no_grad():
- self.graph_logits = self(
- input_ids=self.graph_input_ids,
- position_ids=self.graph_position_ids,
- past_key_values=past_key_values,
- cache_position=self.graph_cache_position,
- cu_seq_lens_q=self.graph_cu_seq_lens_q,
- cu_seq_lens_k=self.graph_cu_seq_lens_k,
- max_length_q=self.graph_max_length_q,
- max_length_k=self.graph_max_length_k,
- block_table=self.graph_block_table,
- prefilling=False,
- ).logits
-
- @torch.no_grad()
- def cuda_graph_decoding(
- self,
- input_ids,
- position_ids,
- past_key_values,
- cu_seq_lens_q,
- cu_seq_lens_k,
- max_length_q,
- max_length_k,
- block_table,
- ):
- if self.cuda_graph is None:
- self.init_cuda_graph(
- input_ids=input_ids,
- position_ids=position_ids,
- past_key_values=past_key_values,
- cu_seq_lens_q=cu_seq_lens_q,
- cu_seq_lens_k=cu_seq_lens_k,
- max_length_q=max_length_q,
- max_length_k=max_length_k,
- block_table=block_table,
- )
-
- self.graph_block_table.copy_(block_table)
- self.graph_cu_seq_lens_q.copy_(cu_seq_lens_q)
- self.graph_cu_seq_lens_k.copy_(cu_seq_lens_k)
- self.graph_max_length_q.copy_(max_length_q)
- self.graph_max_length_k.copy_(max_length_k)
- self.graph_input_ids.copy_(input_ids)
- self.graph_position_ids.copy_(position_ids)
- self.graph_cache_position.copy_(position_ids)
- self.cuda_graph.replay()
-
- return self.graph_logits
-
- @torch.no_grad()
- def generate(
- self,
- input_ids,
- stop_token_ids=[],
- max_batch_size=64,
- max_new_tokens=128,
- max_prefill_batch=16,
- max_length=2048,
- do_sample=False,
- top_k=0,
- top_p=1.0,
- temperature=1.0,
- cuda_graph=False,
- vocab_size=None,
- ):
- assert max_batch_size % max_prefill_batch == 0
- self.patched_model.config.use_cache = True
-
- past_key_values, block_table = self.build_kv_cache(
- max_batch_size, max_length, block_size=256, device=self.device_type
- )
-
- next_input_ids = []
- next_position_ids = []
- next_cu_seq_lens_q = []
- next_cu_seq_lens_k = []
- next_max_length_q = []
- next_max_length_k = []
- next_block_table = []
-
- for start in range(0, max_batch_size, max_prefill_batch):
- _packed_ids, _num_tokens = pack_sequence(
- input_ids[start : start + max_prefill_batch]
- )
- _position_ids = [
- torch.arange(seq.numel())
- for seq in input_ids[start : start + max_prefill_batch]
- ]
- _packed_pos_ids = torch.cat(_position_ids, dim=0).unsqueeze(0)
- _cumulative_length = packed_cumulative_length(_num_tokens)
-
- next_input_ids.append(_packed_ids.to(self.device_type))
- next_position_ids.append(_packed_pos_ids.to(self.device_type))
- next_cu_seq_lens_q.append(_cumulative_length.to(self.device_type))
- next_cu_seq_lens_k.append(_cumulative_length.to(self.device_type))
-
- next_max_length_q.append(_num_tokens.max().item())
- next_max_length_k.append(_num_tokens.max().item())
-
- next_block_table.append(
- block_table[start : start + max_prefill_batch]
- .to(self.device_type)
- .to(torch.int32)
- )
-
- next_is_prefilling = True
-
- num_sessions = len(input_ids)
- stopped = []
- responses = [[] for _ in range(num_sessions)]
-
- self.cuda_graph = None
- self.compiled_model = None
- while True:
- all_rank_stopped = torch.IntTensor([len(stopped) >= num_sessions]).to(
- self.device_type
- )
- torch.distributed.all_reduce(
- all_rank_stopped, torch.distributed.ReduceOp.MIN
- )
- if all_rank_stopped:
- break
-
- if next_is_prefilling:
- if isinstance(next_input_ids, list):
- sampled = []
- seq_lens_q = []
- seq_lens_k = []
- for (
- chunk_input_ids,
- chunk_pos_ids,
- chunk_cu_seq_lens_q,
- chunk_max_length_q,
- chunk_cu_seq_lens_k,
- chunk_max_length_k,
- chunk_block_table,
- ) in zip(
- next_input_ids,
- next_position_ids,
- next_cu_seq_lens_q,
- next_max_length_q,
- next_cu_seq_lens_k,
- next_max_length_k,
- next_block_table,
- strict=True,
- ):
- chunk_outputs = self(
- input_ids=chunk_input_ids,
- position_ids=chunk_pos_ids,
- past_key_values=past_key_values,
- cache_position=chunk_pos_ids,
- cu_seq_lens_q=chunk_cu_seq_lens_q,
- cu_seq_lens_k=chunk_cu_seq_lens_k,
- max_length_q=chunk_max_length_q,
- max_length_k=chunk_max_length_k,
- block_table=chunk_block_table,
- prefilling=next_is_prefilling,
- )
- chunk_logits = chunk_outputs.logits
-
- chunk_sampled = self.sample(
- chunk_logits[0],
- cu_seq_lens=chunk_cu_seq_lens_q,
- do_sample=do_sample,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- )
-
- chunk_seq_lens_q = (
- chunk_cu_seq_lens_q[1:] - chunk_cu_seq_lens_q[:-1]
- )
- chunk_seq_lens_k = (
- chunk_cu_seq_lens_k[1:] - chunk_cu_seq_lens_k[:-1]
- )
-
- sampled.append(chunk_sampled)
- seq_lens_q.append(chunk_seq_lens_q)
- seq_lens_k.append(chunk_seq_lens_k)
-
- sampled = torch.cat(sampled)
- next_input_ids = torch.cat(next_input_ids, dim=1)
- next_position_ids = torch.cat(next_position_ids, dim=1)
- next_block_table = torch.cat(next_block_table, dim=0)
-
- seq_lens_q = torch.cat(seq_lens_q)
- seq_lens_k = torch.cat(seq_lens_k)
-
- next_cu_seq_lens_q = packed_cumulative_length(seq_lens_q)
- next_cu_seq_lens_k = packed_cumulative_length(seq_lens_k)
- next_max_length_q = seq_lens_q.max()
- next_max_length_k = seq_lens_k.max()
-
- else:
- if cuda_graph:
- logits = self.cuda_graph_decoding(
- input_ids=next_input_ids,
- position_ids=next_position_ids,
- past_key_values=past_key_values,
- cu_seq_lens_q=next_cu_seq_lens_q,
- cu_seq_lens_k=next_cu_seq_lens_k,
- max_length_q=next_max_length_q,
- max_length_k=next_max_length_k,
- block_table=next_block_table,
- )
- else:
- outputs = self(
- input_ids=next_input_ids,
- position_ids=next_position_ids,
- past_key_values=past_key_values,
- cache_position=next_position_ids,
- cu_seq_lens_q=next_cu_seq_lens_q,
- cu_seq_lens_k=next_cu_seq_lens_k,
- max_length_q=next_max_length_q,
- max_length_k=next_max_length_k,
- block_table=next_block_table,
- prefilling=next_is_prefilling,
- )
- logits = outputs.logits
-
- sampled = self.sample(
- logits[0],
- cu_seq_lens=next_cu_seq_lens_q,
- do_sample=do_sample,
- top_k=top_k,
- top_p=top_p,
- temperature=temperature,
- )
-
- _next_input_ids = []
- _next_position_ids = []
- _next_seq_lens_q = []
- _next_seq_lens_k = []
- _next_block_table = []
-
- for sess_id in range(num_sessions):
- if sess_id not in stopped:
- token_id = sampled[sess_id]
- responses[sess_id].append(token_id)
- else:
- token_id = responses[sess_id][-1]
-
- _sess_new_tokens = len(responses[sess_id])
- _sess_len = _sess_new_tokens + input_ids[sess_id].numel()
-
- stop = (
- _sess_new_tokens >= max_new_tokens
- or _sess_len >= max_length
- or token_id in stop_token_ids
- )
-
- if stop and sess_id not in stopped:
- stopped.append(sess_id)
-
- _next_block_table.append(next_block_table[sess_id])
- _next_input_ids.append(token_id.reshape(1, -1))
- _next_position_ids.append(torch.arange(_sess_len - 1, _sess_len))
- _next_seq_lens_q.append(1)
- _next_seq_lens_k.append(_sess_len)
-
- _packed_ids, _num_tokens = pack_sequence(_next_input_ids)
- _cumulative_length = packed_cumulative_length(_num_tokens)
-
- next_input_ids = _packed_ids.to(self.device_type)
- next_position_ids = torch.cat(_next_position_ids, dim=0).unsqueeze(0)
- next_position_ids = next_position_ids.to(self.device_type)
-
- _next_seq_lens_q = torch.IntTensor([0] + _next_seq_lens_q).to(
- self.device_type
- )
- _next_seq_lens_k = torch.IntTensor([0] + _next_seq_lens_k).to(
- self.device_type
- )
-
- next_max_length_q = _next_seq_lens_q.max()
- next_max_length_k = _next_seq_lens_k.max()
-
- next_cu_seq_lens_q = torch.cumsum(_next_seq_lens_q, dim=0).int()
- next_cu_seq_lens_k = torch.cumsum(_next_seq_lens_k, dim=0).int()
-
- next_block_table = torch.stack(_next_block_table).to(self.device_type)
-
- next_is_prefilling = False
-
- self.patched_model.config.use_cache = False
-
- del past_key_values
- self.cuda_graph = None
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
-
- return [torch.stack(res).tolist() for res in responses]
diff --git a/code/xtuner/_lite/patches/qwen2.py b/code/xtuner/_lite/patches/qwen2.py
deleted file mode 100644
index cccb893c578935d84944a0b733ee312f7ab79c7c..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/qwen2.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Callable, Optional, Tuple
-
-import torch
-from torch.distributed.device_mesh import DeviceMesh
-from transformers.cache_utils import Cache
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from transformers.models.qwen2.modeling_qwen2 import (
- Qwen2Attention,
- Qwen2DecoderLayer,
- Qwen2ForCausalLM,
- Qwen2RMSNorm,
- Qwen2RotaryEmbedding,
- apply_rotary_pos_emb,
- eager_attention_forward,
- repeat_kv,
-)
-from transformers.processing_utils import Unpack
-from transformers.utils import logging
-
-from xtuner._lite.chat import HybridChatTemplate
-from xtuner._lite.patches.base import FSDPConfig, ModelConfig
-from xtuner._lite.patches.llama import CUDAPatchedLlamaForCausalLM, all_to_all
-
-logger = logging.get_logger(__name__)
-
-
-class CUDAPatchedQwen2ForCausalLM(CUDAPatchedLlamaForCausalLM):
- rotary_emb_cls = Qwen2RotaryEmbedding
- attn_cls = Qwen2Attention
- layer_cls = Qwen2DecoderLayer
- causal_cls = Qwen2ForCausalLM
- norm_cls = Qwen2RMSNorm
-
- chat_template = HybridChatTemplate(
- system="<|im_start|>system\n{system}<|im_end|>\n",
- user="<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n",
- assistant="{assistant}<|im_end|>",
- stop_words=["<|im_end|>", "<|endoftext|>"],
- )
-
- def init_model_config(self, fsdp_config: FSDPConfig):
- assert self.patched_model.config.num_key_value_heads >= fsdp_config.tp_size
- assert self.patched_model.config.num_key_value_heads % fsdp_config.tp_size == 0
- assert (
- self.patched_model.config.hidden_size
- % self.patched_model.config.num_attention_heads
- == 0
- )
-
- self._model_config = ModelConfig(
- num_hidden_layers=self.patched_model.config.num_hidden_layers,
- num_attention_heads=self.patched_model.config.num_attention_heads,
- num_key_value_heads=self.patched_model.config.num_key_value_heads
- // fsdp_config.tp_size,
- hidden_size=self.patched_model.config.hidden_size,
- intermediate_size=self.patched_model.config.intermediate_size,
- vocab_size=self.patched_model.config.vocab_size,
- head_dim=self.patched_model.config.hidden_size
- // self.patched_model.config.num_attention_heads,
- )
-
- @staticmethod
- def patched_attn_forward(
- self: Qwen2Attention,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- position_ids: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = False,
- past_key_value: Optional[Cache] = None,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if "block_table" in kwargs and kwargs["block_table"] is not None:
- if (
- self.config.use_sliding_window
- and getattr(self.config, "sliding_window", None) is not None
- and self.layer_idx >= self.config.max_window_layers
- ):
- raise NotImplementedError
-
- # generating
- if "prefilling" in kwargs and kwargs["prefilling"]:
- return CUDAPatchedLlamaForCausalLM.patched_attn_prefilling(
- self,
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- position_ids=position_ids,
- cache_position=cache_position,
- output_attentions=output_attentions,
- sequence_parallel_mesh=sequence_parallel_mesh,
- **kwargs,
- )
- else:
- return CUDAPatchedLlamaForCausalLM.patched_attn_decoding(
- self,
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- position_ids=position_ids,
- cache_position=cache_position,
- output_attentions=output_attentions,
- sequence_parallel_mesh=sequence_parallel_mesh,
- **kwargs,
- )
- else:
- return CUDAPatchedQwen2ForCausalLM.patched_attn_forward_training(
- self,
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- past_key_value=past_key_value,
- position_ids=position_ids,
- cache_position=cache_position,
- output_attentions=output_attentions,
- sequence_parallel_mesh=sequence_parallel_mesh,
- **kwargs,
- )
-
- @staticmethod
- def patched_attn_forward_training(
- self: Qwen2Attention,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- sequence_parallel_mesh: Optional[DeviceMesh] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin
- )
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs
- )
-
- sliding_window = None
- if (
- self.config.use_sliding_window
- and getattr(self.config, "sliding_window", None) is not None
- and self.layer_idx >= self.config.max_window_layers
- ):
- sliding_window = self.config.sliding_window
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get(
- "output_attentions", False
- ):
- logger.warning_once(
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
- "Falling back to eager attention. This warning can be removed using the argument "
- '`attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[
- self.config._attn_implementation
- ]
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- sp_size = sequence_parallel_mesh.size()
- num_kv_heads = key_states.size(1)
- if sp_size > num_kv_heads:
- assert sp_size % num_kv_heads == 0
- key_states = repeat_kv(key_states, sp_size // num_kv_heads)
- value_states = repeat_kv(value_states, sp_size // num_kv_heads)
-
- query_states = all_to_all(
- query_states, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
- key_states = all_to_all(
- key_states, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
- value_states = all_to_all(
- value_states, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
-
- # (bs, n , qh // sp, d)
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=sliding_window,
- **kwargs,
- )
-
- if sequence_parallel_mesh and sequence_parallel_mesh.size() > 1:
- attn_output = all_to_all(
- attn_output, scatter_dim=1, gather_dim=2, mesh=sequence_parallel_mesh
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
diff --git a/code/xtuner/_lite/patches/utils.py b/code/xtuner/_lite/patches/utils.py
deleted file mode 100644
index 9e847c627f8eb04218e4ec47dc096f4646113253..0000000000000000000000000000000000000000
--- a/code/xtuner/_lite/patches/utils.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Union
-
-import torch
-
-
-def pad_to_multiple_of(sequence, padding_value, multiple_of, dim=-1):
- length = sequence.shape[dim]
- if length % multiple_of == 0:
- return sequence
-
- pad_num = multiple_of - (length % multiple_of)
- pad_shape = (
- (*sequence.shape[:dim], pad_num, *sequence.shape[dim + 1 :])
- if dim != -1
- else (*sequence.shape[:dim], pad_num)
- )
- pad = torch.full(
- pad_shape, padding_value, dtype=sequence.dtype, device=sequence.device
- )
- sequence = torch.cat([sequence, pad], dim=dim)
- return sequence
-
-
-def pad_to_max_length(sequence, padding_value, max_length, dim=-1):
- length = sequence.shape[dim]
- assert length <= max_length
- pad_num = max_length - length
- pad_shape = (
- (*sequence.shape[:dim], pad_num, *sequence.shape[dim + 1 :])
- if dim != -1
- else (*sequence.shape[:dim], pad_num)
- )
- pad = torch.full(
- pad_shape, padding_value, dtype=sequence.dtype, device=sequence.device
- )
- sequence = torch.cat([sequence, pad], dim=dim)
- return sequence
-
-
-def unpack_sequence(packed: torch.Tensor, num_tokens: Union[torch.Tensor, List], dim=1):
- if isinstance(num_tokens, torch.Tensor):
- num_tokens = num_tokens.tolist()
- sequences = torch.split(packed, num_tokens, dim=dim)
- return sequences
-
-
-def pack_sequence(sequences, dim=1):
- num_tokens = torch.IntTensor([seq.size(dim) for seq in sequences])
- packed = torch.cat(sequences, dim=dim)
- return packed, num_tokens.to(packed.device)
-
-
-def packed_cumulative_length(num_tokens: torch.Tensor):
- device = num_tokens.device
- _zero_pad = torch.zeros(1, device=device)
- _pad_length = torch.cat([_zero_pad, num_tokens]).int()
- return torch.cumsum(_pad_length, 0).int()
diff --git a/code/xtuner/apis/__init__.py b/code/xtuner/apis/__init__.py
deleted file mode 100644
index f49d493789960175c39a59a0b62e0fae44513766..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .datasets import * # noqa: F401, F403
-from .model import * # noqa: F401, F403
-from .training_args import * # noqa: F401, F403
diff --git a/code/xtuner/apis/datasets/__init__.py b/code/xtuner/apis/datasets/__init__.py
deleted file mode 100644
index 4ff4fe4789522dd117c77fe74e1c381ead461e91..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/__init__.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .alpaca import (alpaca_data_collator, alpaca_dataset,
- alpaca_enzh_data_collator, alpaca_enzh_dataset,
- alpaca_zh_data_collator, alpaca_zh_dataset)
-from .arxiv import arxiv_data_collator, arxiv_dataset
-from .code_alpaca import code_alpaca_data_collator, code_alpaca_dataset
-from .colorist import colorist_data_collator, colorist_dataset
-from .lawyer import (lawyer_crime_data_collator, lawyer_crime_dataset,
- lawyer_data_collator, lawyer_dataset,
- lawyer_reference_data_collator, lawyer_reference_dataset)
-from .medical import medical_data_collator, medical_dataset
-from .moss_003_sft import (moss_003_sft_data_collator, moss_003_sft_dataset,
- moss_003_sft_no_plugins_data_collator,
- moss_003_sft_no_plugins_dataset,
- moss_003_sft_plugins_data_collator,
- moss_003_sft_plugins_dataset)
-from .oasst1 import oasst1_data_collator, oasst1_dataset
-from .open_orca import openorca_data_collator, openorca_dataset
-from .sql import sql_data_collator, sql_dataset
-from .tiny_codes import tiny_codes_data_collator, tiny_codes_dataset
-from .wizardlm import wizardlm_data_collator, wizardlm_dataset
-
-__all__ = [
- 'alpaca_data_collator', 'alpaca_dataset', 'alpaca_enzh_data_collator',
- 'alpaca_enzh_dataset', 'alpaca_zh_data_collator', 'alpaca_zh_dataset',
- 'arxiv_data_collator', 'arxiv_dataset', 'medical_data_collator',
- 'medical_dataset', 'moss_003_sft_data_collator', 'moss_003_sft_dataset',
- 'moss_003_sft_no_plugins_data_collator', 'moss_003_sft_no_plugins_dataset',
- 'moss_003_sft_plugins_data_collator', 'moss_003_sft_plugins_dataset',
- 'oasst1_data_collator', 'oasst1_dataset', 'openorca_data_collator',
- 'openorca_dataset', 'lawyer_crime_dataset', 'lawyer_crime_data_collator',
- 'lawyer_reference_dataset', 'lawyer_reference_data_collator',
- 'lawyer_dataset', 'lawyer_data_collator', 'colorist_dataset',
- 'colorist_data_collator', 'sql_dataset', 'sql_data_collator',
- 'code_alpaca_dataset', 'code_alpaca_data_collator', 'tiny_codes_dataset',
- 'tiny_codes_data_collator', 'wizardlm_data_collator', 'wizardlm_dataset'
-]
diff --git a/code/xtuner/apis/datasets/alpaca.py b/code/xtuner/apis/datasets/alpaca.py
deleted file mode 100644
index 8e284a9375f8ae286083e29c1ba92549414caff5..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/alpaca.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-from torch.utils.data import ConcatDataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
- template_map_fn_factory)
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def alpaca_enzh_dataset(tokenizer,
- path_en='tatsu-lab/alpaca',
- path_zh='silk-road/alpaca-data-gpt4-chinese',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- alpaca = alpaca_dataset(
- tokenizer,
- path=path_en,
- max_length=max_length,
- prompt_template=prompt_template,
- shuffle_before_pack=True,
- remove_unused_columns=remove_unused_columns,
- pack_to_max_length=pack_to_max_length)
- alpaca_zh = alpaca_zh_dataset(
- tokenizer,
- path=path_zh,
- max_length=max_length,
- prompt_template=prompt_template,
- shuffle_before_pack=True,
- remove_unused_columns=remove_unused_columns,
- pack_to_max_length=pack_to_max_length)
- dataset = ConcatDataset([alpaca, alpaca_zh])
- return dataset
-
-
-def alpaca_enzh_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
-
-
-def alpaca_zh_dataset(tokenizer,
- path='silk-road/alpaca-data-gpt4-chinese',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=alpaca_zh_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def alpaca_zh_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
-
-
-def alpaca_dataset(tokenizer,
- path='tatsu-lab/alpaca',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=alpaca_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def alpaca_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/arxiv.py b/code/xtuner/apis/datasets/arxiv.py
deleted file mode 100644
index 35521f3ea80b67fec779576a48de4779d59a8bb4..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/arxiv.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import arxiv_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def arxiv_dataset(tokenizer,
- data_file=None,
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- # 1. Download data from https://kaggle.com/datasets/Cornell-University/arxiv # noqa: E501
- # 2. Process data with `./tools/data_preprocess/arxiv.py`
- if data_file is None:
- data_file = './data/arxiv_postprocess_csAIcsCLcsCV_20200101.json'
- dataset_org = load_dataset(path='json', data_files=dict(train=data_file))
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=arxiv_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def arxiv_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/code_alpaca.py b/code/xtuner/apis/datasets/code_alpaca.py
deleted file mode 100644
index de3f94e24fb529932894143ee1a477ec1d06221e..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/code_alpaca.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import code_alpaca_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def code_alpaca_dataset(tokenizer,
- path='HuggingFaceH4/CodeAlpaca_20K',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=code_alpaca_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def code_alpaca_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/colorist.py b/code/xtuner/apis/datasets/colorist.py
deleted file mode 100644
index 00400d09e62be767b026a170ee7c2aaad26e6f97..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/colorist.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import colors_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def colorist_dataset(tokenizer,
- path='burkelibbey/colors',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=colors_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def colorist_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/lawyer.py b/code/xtuner/apis/datasets/lawyer.py
deleted file mode 100644
index 287dc2827d7cdf6ab54649af2434b9e270b8f155..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/lawyer.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-from torch.utils.data import ConcatDataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import (crime_kg_assitant_map_fn,
- law_reference_map_fn,
- template_map_fn_factory)
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def lawyer_dataset(tokenizer,
- crime_data_file=None,
- reference_data_file=None,
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- crime_dataset = lawyer_crime_dataset(
- tokenizer,
- data_file=crime_data_file,
- max_length=max_length,
- prompt_template=prompt_template,
- remove_unused_columns=remove_unused_columns,
- pack_to_max_length=pack_to_max_length)
- reference_dataset = lawyer_reference_dataset(
- tokenizer,
- data_file=reference_data_file,
- max_length=max_length,
- prompt_template=prompt_template,
- remove_unused_columns=remove_unused_columns,
- pack_to_max_length=pack_to_max_length)
- dataset = ConcatDataset([crime_dataset, reference_dataset])
- return dataset
-
-
-def lawyer_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
-
-
-def lawyer_crime_dataset(tokenizer,
- data_file=None,
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- # Download data from https://github.com/LiuHC0428/LAW-GPT # noqa: E501
- if data_file is None:
- data_file = './data/law/CrimeKgAssitant清洗后_52k.json'
- dataset_org = load_dataset(path='json', data_files=dict(train=data_file))
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=crime_kg_assitant_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def lawyer_crime_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
-
-
-def lawyer_reference_dataset(tokenizer,
- data_file=None,
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- # Download data from https://github.com/LiuHC0428/LAW-GPT # noqa: E501
- if data_file is None:
- data_file = './data/law/训练数据_带法律依据_92k.json'
- dataset_org = load_dataset(path='json', data_files=dict(train=data_file))
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=law_reference_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def lawyer_reference_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/medical.py b/code/xtuner/apis/datasets/medical.py
deleted file mode 100644
index cd430b8d24a7cc007be5d1677273a9172071d72b..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/medical.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import medical_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def medical_dataset(tokenizer,
- path='shibing624/medical',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=False,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=medical_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def medical_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/moss_003_sft.py b/code/xtuner/apis/datasets/moss_003_sft.py
deleted file mode 100644
index 7952238cf00132b142a8a0877d0e104424a49bcc..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/moss_003_sft.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from torch.utils.data import ConcatDataset
-
-from xtuner.dataset import MOSSSFTDataset
-from xtuner.dataset.collate_fns import default_collate_fn
-
-
-def moss_003_sft_dataset(tokenizer,
- plugins_data_file=None,
- no_plugins_data_file=None,
- bot_name=None,
- max_length=2048):
- plugins = moss_003_sft_plugins_dataset(
- tokenizer,
- data_file=plugins_data_file,
- bot_name=bot_name,
- max_length=max_length)
- no_plugins = moss_003_sft_no_plugins_dataset(
- tokenizer,
- data_file=no_plugins_data_file,
- bot_name=bot_name,
- max_length=max_length)
- dataset = ConcatDataset([plugins, no_plugins])
- return dataset
-
-
-def moss_003_sft_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
-
-
-def moss_003_sft_no_plugins_dataset(tokenizer,
- data_file=None,
- bot_name=None,
- max_length=2048):
-
- # Download data from https://huggingface.co/datasets/fnlp/moss-003-sft-data
- if data_file is None:
- data_file = './data/moss-003-sft-no-tools.jsonl'
- dataset = MOSSSFTDataset(
- data_file=data_file,
- bot_name=bot_name,
- tokenizer=tokenizer,
- max_length=max_length)
-
- return dataset
-
-
-def moss_003_sft_no_plugins_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
-
-
-def moss_003_sft_plugins_dataset(tokenizer,
- data_file=None,
- bot_name=None,
- max_length=2048):
-
- # Download data from https://huggingface.co/datasets/fnlp/moss-003-sft-data
- if data_file is None:
- data_file = './data/conversations_with_tools_with_inner_instruction_no_text2image_train_all_random_meta0.5_0.1_0.01_moss_0709.jsonl' # noqa: E501
- dataset = MOSSSFTDataset(
- data_file=data_file,
- bot_name=bot_name,
- tokenizer=tokenizer,
- max_length=max_length)
-
- return dataset
-
-
-def moss_003_sft_plugins_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/oasst1.py b/code/xtuner/apis/datasets/oasst1.py
deleted file mode 100644
index 0b877239622ed68bc886efcf13a2936772005118..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/oasst1.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def oasst1_dataset(tokenizer,
- path='timdettmers/openassistant-guanaco',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=False,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=oasst1_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def oasst1_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/open_orca.py b/code/xtuner/apis/datasets/open_orca.py
deleted file mode 100644
index 9e52d50e2271005ef87ac1952a13fe391b77a207..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/open_orca.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import openorca_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def openorca_dataset(tokenizer,
- path='Open-Orca/OpenOrca',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=openorca_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def openorca_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/sql.py b/code/xtuner/apis/datasets/sql.py
deleted file mode 100644
index fed725ee05707fe455b7cfcf4cc8bf1621f32696..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/sql.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import sql_map_fn, template_map_fn_factory
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def sql_dataset(tokenizer,
- path='b-mc2/sql-create-context',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=sql_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def sql_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/tiny_codes.py b/code/xtuner/apis/datasets/tiny_codes.py
deleted file mode 100644
index 286d65e4f6e1e13b831e52f15ad98fc072a72719..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/tiny_codes.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import template_map_fn_factory, tiny_codes_map_fn
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def tiny_codes_dataset(tokenizer,
- path='nampdn-ai/tiny-codes',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=True,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=tiny_codes_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def tiny_codes_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/datasets/wizardlm.py b/code/xtuner/apis/datasets/wizardlm.py
deleted file mode 100644
index b5a084271075da12577fd0560b8572e9cd0eeb20..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/datasets/wizardlm.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from datasets import load_dataset
-
-from xtuner.dataset import process_hf_dataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import template_map_fn_factory, wizardlm_map_fn
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-def wizardlm_dataset(tokenizer,
- path='WizardLM/WizardLM_evol_instruct_V2_196k',
- max_length=2048,
- prompt_template=PROMPT_TEMPLATE.default,
- remove_unused_columns=False,
- pack_to_max_length=True):
- template_map_fn = template_map_fn_factory(template=prompt_template)
- dataset_org = load_dataset(path)
- dataset = process_hf_dataset(
- dataset=dataset_org,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=wizardlm_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=remove_unused_columns,
- shuffle_before_pack=True,
- pack_to_max_length=pack_to_max_length)
-
- return dataset
-
-
-def wizardlm_data_collator(return_hf_format=False):
- return partial(default_collate_fn, return_hf_format=return_hf_format)
diff --git a/code/xtuner/apis/model.py b/code/xtuner/apis/model.py
deleted file mode 100644
index efd9370df19a1b258fa7c93ef31284fff42dd589..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/model.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from peft import LoraConfig
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig)
-
-from xtuner.model import SupervisedFinetune
-
-__all__ = ['build_model', 'build_lora_model', 'build_qlora_model']
-
-
-def build_qlora_model(model_name_or_path,
- quantization_config=None,
- lora_config=None,
- return_tokenizer=True):
-
- if quantization_config is None:
- quantization_config = BitsAndBytesConfig(
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4')
- if lora_config is None:
- lora_config = LoraConfig(
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM')
-
- llm = AutoModelForCausalLM.from_pretrained(
- model_name_or_path,
- torch_dtype=torch.float16,
- trust_remote_code=True,
- quantization_config=quantization_config)
-
- model = SupervisedFinetune(llm, lora=lora_config)
-
- if return_tokenizer:
- tokenizer = AutoTokenizer.from_pretrained(
- model_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
- return model.llm, tokenizer
- else:
- return model.llm
-
-
-def build_lora_model(model_name_or_path,
- lora_config=None,
- return_tokenizer=True):
- if lora_config is None:
- lora_config = LoraConfig(
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM')
-
- llm = AutoModelForCausalLM.from_pretrained(
- model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True)
-
- model = SupervisedFinetune(llm, lora=lora_config)
-
- if return_tokenizer:
- tokenizer = AutoTokenizer.from_pretrained(
- model_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
- return model.llm, tokenizer
- else:
- return model.llm
-
-
-def build_model(model_name_or_path, return_tokenizer=True):
- model = AutoModelForCausalLM.from_pretrained(
- model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True)
-
- if return_tokenizer:
- tokenizer = AutoTokenizer.from_pretrained(
- model_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
- return model, tokenizer
- else:
- return model
diff --git a/code/xtuner/apis/training_args.py b/code/xtuner/apis/training_args.py
deleted file mode 100644
index b0f65445c2e273e43244682e035e8e0a729bdd31..0000000000000000000000000000000000000000
--- a/code/xtuner/apis/training_args.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from dataclasses import dataclass, field
-from typing import Union
-
-from transformers import TrainingArguments
-from transformers.trainer_utils import IntervalStrategy, SchedulerType
-
-__all__ = ['DefaultTrainingArguments']
-
-
-@dataclass
-class DefaultTrainingArguments(TrainingArguments):
- # custom
- model_name_or_path: str = field(
- default=None,
- metadata={'help': 'model name or path.'},
- )
- dataset_name_or_path: str = field(
- default=None,
- metadata={'help': 'dataset name or path.'},
- )
-
- # huggingface
- default_output_dir = './work_dirs'
- default_do_train = True
- default_per_device_train_batch_size = 1
- default_learning_rate = 2e-5
- default_save_strategy = 'epoch'
- default_lr_scheduler_type = 'cosine'
- default_logging_steps = 5
-
- output_dir: str = field(
- default=default_output_dir,
- metadata={
- 'help': ('The output directory where the model predictions and '
- 'checkpoints will be written.')
- })
- do_train: bool = field(
- default=default_do_train,
- metadata={'help': 'Whether to run training.'})
- per_device_train_batch_size: int = field(
- default=default_per_device_train_batch_size,
- metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'})
- learning_rate: float = field(
- default=default_learning_rate,
- metadata={'help': 'The initial learning rate for AdamW.'})
- save_strategy: Union[IntervalStrategy, str] = field(
- default=default_save_strategy,
- metadata={'help': 'The checkpoint save strategy to use.'},
- )
- lr_scheduler_type: Union[SchedulerType, str] = field(
- default=default_lr_scheduler_type,
- metadata={'help': 'The scheduler type to use.'},
- )
- logging_steps: float = field(
- default=default_logging_steps,
- metadata={
- 'help': ('Log every X updates steps. Should be an integer or a '
- 'float in range `[0,1)`. If smaller than 1, will be '
- 'interpreted as ratio of total training steps.')
- })
diff --git a/code/xtuner/configs/__init__.py b/code/xtuner/configs/__init__.py
deleted file mode 100644
index 7dc9a90264f47ea1294814e81d742954a689a98d..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-
-
-def get_cfgs_name_path():
- path = os.path.dirname(__file__)
- mapping = {}
- for root, dirs, files in os.walk(path):
- root = "/data/qingq/PathVLM/baselines/github/SlideChat/xtuner"
- for file_ in files:
- if file_.endswith(
- ('.py', '.json')
- ) and not file_.startswith('.') and not file_.startswith('_'):
- mapping[os.path.splitext(file_)[0]] = os.path.join(root, file_)
- return mapping
-
-
-cfgs_name_path = get_cfgs_name_path()
-
-__all__ = ['cfgs_name_path']
diff --git a/code/xtuner/configs/__pycache__/__init__.cpython-311.pyc b/code/xtuner/configs/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 99e3e7afdfd19d4d14ec9eaf5148022cee45bd78..0000000000000000000000000000000000000000
Binary files a/code/xtuner/configs/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/configs/deepspeed/deepspeed_zero1.json b/code/xtuner/configs/deepspeed/deepspeed_zero1.json
deleted file mode 100644
index 552176573381f23d87f79b2b3990302dd6a69039..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/deepspeed/deepspeed_zero1.json
+++ /dev/null
@@ -1,18 +0,0 @@
-{
- "gradient_accumulation_steps": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "gradient_clipping": "auto",
- "zero_allow_untested_optimizer": true,
- "zero_force_ds_cpu_optimizer": false,
- "zero_optimization": {
- "stage": 1,
- "overlap_comm": true
- },
- "fp16": {
- "enabled": "auto",
- "initial_scale_power": 16
- },
- "bf16": {
- "enabled": "auto"
- }
-}
diff --git a/code/xtuner/configs/deepspeed/deepspeed_zero2.json b/code/xtuner/configs/deepspeed/deepspeed_zero2.json
deleted file mode 100644
index 8505b490890978b0d93f2f74be9d5e74ab31428b..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/deepspeed/deepspeed_zero2.json
+++ /dev/null
@@ -1,18 +0,0 @@
-{
- "gradient_accumulation_steps": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "gradient_clipping": "auto",
- "zero_allow_untested_optimizer": true,
- "zero_force_ds_cpu_optimizer": false,
- "zero_optimization": {
- "stage": 2,
- "overlap_comm": false
- },
- "fp16": {
- "enabled": "auto",
- "initial_scale_power": 16
- },
- "bf16": {
- "enabled": "auto"
- }
-}
diff --git a/code/xtuner/configs/deepspeed/deepspeed_zero2_offload.json b/code/xtuner/configs/deepspeed/deepspeed_zero2_offload.json
deleted file mode 100644
index 4376d2ddb1b16cc26db7c50cea24d2f91de21aff..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/deepspeed/deepspeed_zero2_offload.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
- "gradient_accumulation_steps": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "gradient_clipping": "auto",
- "zero_allow_untested_optimizer": true,
- "zero_force_ds_cpu_optimizer": false,
- "zero_optimization": {
- "stage": 2,
- "overlap_comm": true,
- "offload_optimizer": {
- "device": "cpu",
- "pin_memory": true
- }
- },
- "fp16": {
- "enabled": "auto",
- "initial_scale_power": 16
- },
- "bf16": {
- "enabled": "auto"
- }
-}
diff --git a/code/xtuner/configs/deepspeed/deepspeed_zero3.json b/code/xtuner/configs/deepspeed/deepspeed_zero3.json
deleted file mode 100644
index 2922a1cb6332e5aab833f2cbdc9d34ca355cb15b..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/deepspeed/deepspeed_zero3.json
+++ /dev/null
@@ -1,20 +0,0 @@
-{
- "gradient_accumulation_steps": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "gradient_clipping": "auto",
- "zero_allow_untested_optimizer": true,
- "zero_force_ds_cpu_optimizer": false,
- "zero_optimization": {
- "stage": 3,
- "overlap_comm": true,
- "stage3_gather_16bit_weights_on_model_save": true,
- "stage3_param_persistence_threshold": 0
- },
- "fp16": {
- "enabled": "auto",
- "initial_scale_power": 16
- },
- "bf16": {
- "enabled": "auto"
- }
-}
diff --git a/code/xtuner/configs/deepspeed/deepspeed_zero3_offload.json b/code/xtuner/configs/deepspeed/deepspeed_zero3_offload.json
deleted file mode 100644
index 02ec67912817fa2039547a7c081e71e8bfbbe0d6..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/deepspeed/deepspeed_zero3_offload.json
+++ /dev/null
@@ -1,28 +0,0 @@
-{
- "gradient_accumulation_steps": "auto",
- "train_micro_batch_size_per_gpu": "auto",
- "gradient_clipping": "auto",
- "zero_allow_untested_optimizer": true,
- "zero_force_ds_cpu_optimizer": false,
- "zero_optimization": {
- "stage": 3,
- "overlap_comm": false,
- "stage3_gather_16bit_weights_on_model_save": true,
- "stage3_param_persistence_threshold": 1e6,
- "offload_optimizer": {
- "device": "cpu",
- "pin_memory": true
- },
- "offload_param": {
- "device": "cpu",
- "pin_memory": true
- }
- },
- "fp16": {
- "enabled": "auto",
- "initial_scale_power": 16
- },
- "bf16": {
- "enabled": "auto"
- }
-}
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_blca.py b/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_blca.py
deleted file mode 100644
index 4b56e8e4e7287b3b422b8a769211549e87084958..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_blca.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- # llm_lora=dict(
- # type=LoraConfig,
- # r=64,
- # lora_alpha=16,
- # lora_dropout=0.1,
- # bias='none',
- # task_type='CAUSAL_LM'),
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='conv',
- kernel_size=5,
- stride=5
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_brca.py b/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_brca.py
deleted file mode 100644
index db08d2b842f612fb1b1bea04bc5c4b0a5a9b5f88..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_brca.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=True, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
-# # LoRA Config
-# llm_lora=dict(
-# type=LoraConfig,
-# r=64,
-# lora_alpha=16,
-# lora_dropout=0.1,
-# bias='none',
-# task_type='CAUSAL_LM'),
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='conv',
- kernel_size=5,
- stride=5
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_luad.py b/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_luad.py
deleted file mode 100644
index d0e64066c56885922b44c715c700fdc54f0fd4ea..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_1d_conv/stage_2_reducer_1d_conv_luad.py
+++ /dev/null
@@ -1,188 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=True, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
-# # LoRA Config
-# llm_lora=dict(
-# type=LoraConfig,
-# r=64,
-# lora_alpha=16,
-# lora_dropout=0.1,
-# bias='none',
-# task_type='CAUSAL_LM'),
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='conv',
- kernel_size=5,
- stride=5
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_blca.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_blca.py
deleted file mode 100644
index 2ab298edcbd821f8a3f1545efc1b2e566d978015..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_blca.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_brca.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_brca.py
deleted file mode 100644
index e8475441b19ad1448553def45e98d4fd988b5579..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_brca.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_coad.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_coad.py
deleted file mode 100644
index 6c06d3d33216c3939f88f763c7a2cfe36e3e1f9d..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_coad.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/COAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_gbm.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_gbm.py
deleted file mode 100644
index e3d7b262ade276e1d4aad30f627618df2584c782..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_gbm.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/GBM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_hnsc.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_hnsc.py
deleted file mode 100644
index 2f1748f72256de26cb1a555c7799a1165bee719f..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_hnsc.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/HNSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_lgg.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_lgg.py
deleted file mode 100644
index ac42da9a416e7c0488d271c7975ae4a55570af06..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_lgg.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_luad.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_luad.py
deleted file mode 100644
index aa1cafdb8828b091869d8a3c4db121aaf5ba6c23..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_luad.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_lusc.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_lusc.py
deleted file mode 100644
index bf5e97ec7aa15c1a5dd8ce2cf914ea25e602ef88..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_lusc.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_read.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_read.py
deleted file mode 100644
index 7c73d2782ef129a350669f14a598b280c700b9af..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_read.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/READ.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_skcm.py b/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_skcm.py
deleted file mode 100644
index bacd36e3f514772dae8db0c504379a003ed1234a..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_skcm.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_ACMIL
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_ACMIL,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- acmil_type = 'ga',
- acmil_tokens = 2000,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_blca.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_blca.py
deleted file mode 100644
index f683a9c37943e58829383b8fad5a6f0c828e36c2..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_blca.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_brca.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_brca.py
deleted file mode 100644
index 40414a216ce1a9b290d33a2e512f672a60e936cb..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_brca.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_coad.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_coad.py
deleted file mode 100644
index 6446236547b02d43e364c7a7c467a79f3b422ed6..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_coad.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/COAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_gbm.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_gbm.py
deleted file mode 100644
index 8dcdec1023f200ce72c8e0974f957096a122ff26..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_gbm.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/GBM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_hnsc.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_hnsc.py
deleted file mode 100644
index eefcaa05387370b28ec49506062f8135eec3e735..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_hnsc.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/HNSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_lgg.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_lgg.py
deleted file mode 100644
index 9081d68dc6d92d55afc19dcd82535d4e4f7061f4..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_lgg.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_luad.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_luad.py
deleted file mode 100644
index a3ff5883885f827c57293cd2637c4309e6cad956..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_luad.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_lusc.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_lusc.py
deleted file mode 100644
index 927849a6a002e8be9852c431dd874647ac5f76f5..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_lusc.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_read.py b/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_read.py
deleted file mode 100644
index b237523213a74375a5e1a9b6ac28755a42f170f8..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_attn/stage2_reducer_attn_read.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/READ.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_blca.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_blca.py
deleted file mode 100644
index c896ea5f184d37c4d9aec391e9f7702eb3c2e7e9..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_blca.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=14,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_brca.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_brca.py
deleted file mode 100644
index a22d98416d654dd54fb0d951334c10f75ef7a4c9..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_brca.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_coad.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_coad.py
deleted file mode 100644
index 180806621dfa95a996e8c7cc6f65c35833e25d02..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_coad.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/COAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_gbm.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_gbm.py
deleted file mode 100644
index d8f435936e2de7104eb51fffc918eca5da127023..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_gbm.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/GBM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_hnsc.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_hnsc.py
deleted file mode 100644
index 02c8c00a264894a99be51caaf306e27ff2830416..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_hnsc.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/HNSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_lgg.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_lgg.py
deleted file mode 100644
index 2f52829373f51bff22d0cfb8410443fdfffe30f3..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_lgg.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_luad.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_luad.py
deleted file mode 100644
index 8fb69f75feef23552c0c8bb0841a936c6adbf85d..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_luad.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_lusc.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_lusc.py
deleted file mode 100644
index 4962ef07f9874eb2c2640f09767d626e11d227dd..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_lusc.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_read.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_read.py
deleted file mode 100644
index 8c36e9d9269261a37811c91619f311e5e61e6451..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_read.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/READ.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_skcm.py b/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_skcm.py
deleted file mode 100644
index 8be832b982fc06b30d0b147df93a2d3ed81c784c..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_dynamic_llava/stage_2_dynamic_llava_skcm.py
+++ /dev/null
@@ -1,225 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data (tumor-specific)
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 10240
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_blca.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_blca.py
deleted file mode 100644
index 799b3c22950929d0ef2d604acd0c41eebea13271..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_blca.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_brca.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_brca.py
deleted file mode 100644
index 8633697e4e382e587ca61290656601ecef1f464f..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_brca.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_coad.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_coad.py
deleted file mode 100644
index 71207ea9ea0cfa2b248c66fffda44b3a8f26ad69..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_coad.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/COAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_gbm.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_gbm.py
deleted file mode 100644
index d01b9723be070d383cb3a6354d2d860754894334..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_gbm.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/GBM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_hnsc.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_hnsc.py
deleted file mode 100644
index 7bc9a21bccb4be6a9053cd37967a4a74a50cd7b8..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_hnsc.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/HNSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_lgg.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_lgg.py
deleted file mode 100644
index 4a83710f9e488760db78f4e325997dde8de1b703..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_lgg.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_luad.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_luad.py
deleted file mode 100644
index 0bab3cfc6f8e2e9bcaaeab302fea33b2a8809aee..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_luad.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_lusc.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_lusc.py
deleted file mode 100644
index ad34a01c94798582060f92dbf5a8950081734349..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_lusc.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_read.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_read.py
deleted file mode 100644
index d40f6b2a6c71b9766d2a27f10859678441fe0ec7..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_read.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/READ.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_skcm.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_skcm.py
deleted file mode 100644
index fad972f0ddb212fce29d8c9ab7e5345a6d238cae..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_skcm.py
+++ /dev/null
@@ -1,175 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_100.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_100.py
deleted file mode 100644
index 54ed3c6bcf5f0f5c568bafe6e3af012f38c91a85..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_100.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=100, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_1000.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_1000.py
deleted file mode 100644
index ea02f07c799f8efa094d8a61c8069cfc8da0e0ba..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_1000.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=1000, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_200.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_200.py
deleted file mode 100644
index c40f5c5147f9f8ed83d1647850750c259aaf82a6..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_200.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=200, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_2000.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_2000.py
deleted file mode 100644
index 154446d448d9cb36dde10c997f1302e69e1256cd..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_2000.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_4000.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_4000.py
deleted file mode 100644
index 6363379d0554cc65a70735c6c697017eff4d6f62..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_4000.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=4000, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_500.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_500.py
deleted file mode 100644
index e477f0bec4cec470197d6c8d27c041fcfce39880..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_brca_500.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=500, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_100.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_100.py
deleted file mode 100644
index 0f7f474d308d3558c8ca70c313486cd366a4a650..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_100.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=100, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_1000.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_1000.py
deleted file mode 100644
index 2288ec2b7ad776452e4032c7d02668ff19240bed..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_1000.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=1000, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_200.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_200.py
deleted file mode 100644
index 2b14ebaaecc5a6e0e5e9a2280e9ccb7cbbed4cb1..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_200.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=200, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_2000.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_2000.py
deleted file mode 100644
index 6d497d9a254cc44a61840e19d6b93c8098934b5b..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_2000.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_4000.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_4000.py
deleted file mode 100644
index 638a5fc80cd1bd88c02c80a7dfe3e00c7b9439c1..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_4000.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=4000, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_500.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_500.py
deleted file mode 100644
index 1e47de7a9c4c97819a828d0b1ac2ccd6d233699b..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_luad_500.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=500, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_skcm_100.py b/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_skcm_100.py
deleted file mode 100644
index 8a9a6bd74424ad3b428a513dcdc78bfd61023939..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers/stage_2_visual_only_fusion_compressor_skcm_100.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=100, # MODIFIED: Grid size now parameterized
- prefusion_layer_num = 2,
- image_only = False, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_blca.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_blca.py
deleted file mode 100644
index 710e479945ea7bac8b25c77f4f7112a1381a9a92..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_blca.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 2
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_brca.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_brca.py
deleted file mode 100644
index 99410546fb30259abffddc71beecdd74baac4b9a..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_brca.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_coad.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_coad.py
deleted file mode 100644
index e3606d418961a50289e673257ba839a8ca11d558..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_coad.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/COAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_gbm.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_gbm.py
deleted file mode 100644
index e99c079b57acf7c9a2f93eedd42de20144872fbb..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_gbm.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/GBM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_hnsc.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_hnsc.py
deleted file mode 100644
index f6c447cde1021cd7a3a6e725474864d179515a63..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_hnsc.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/HNSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_lgg.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_lgg.py
deleted file mode 100644
index 75c9debc295c04098e2ab12aaa06890a15320ffe..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_lgg.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_luad.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_luad.py
deleted file mode 100644
index 48711685b61f7b3740aa1e6c6097e996345db864..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_luad.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_lusc.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_lusc.py
deleted file mode 100644
index 08d4ce41844c7d4ab369c00ee8a2319106f69eb4..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_lusc.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_read.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_read.py
deleted file mode 100644
index 0a0220b2a91783b052f51b66c8ecd3e1557d7b4e..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_read.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/READ.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_skcm.py b/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_skcm.py
deleted file mode 100644
index 9c72cd4603cadb4e8b2ec7b888926b281c2c3126..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_skcm.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_Compressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_Compressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000, # 196 patches -> 49 tokens
- compressor_embed_dim = 512,
- prefusion_layer_num = 2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_blca.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_blca.py
deleted file mode 100644
index 17e9fe92d7e2163359d2214363592520dd965ac7..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_blca.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BLCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_brca.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_brca.py
deleted file mode 100644
index 777e8d9fc7aee932a57fc533ffafcfae8573a500..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_brca.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/BRCA.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_coad.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_coad.py
deleted file mode 100644
index 3aa4f1a0b2639eae2df45d2a15c767e1dbd1347c..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_coad.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/COAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_gbm.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_gbm.py
deleted file mode 100644
index aee573f4e7498f5232a978367e4dd30440d90b87..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_gbm.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/GBM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_hnsc.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_hnsc.py
deleted file mode 100644
index 9f359d34f90005599265d9a1bb3a7e55d104f45c..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_hnsc.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/HNSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_lgg.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_lgg.py
deleted file mode 100644
index 00b3e15e46977be53d0de442885491f95f341e9d..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_lgg.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_luad.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_luad.py
deleted file mode 100644
index c493dcc762d78a891ffe1a197af935218d769ba2..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_luad.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUAD.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_lusc.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_lusc.py
deleted file mode 100644
index ea1f54076271de321e888c7824eba7defc5b45dd..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_lusc.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LUSC.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_read.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_read.py
deleted file mode 100644
index 729bf0756c1caf42f4914c758b211718eb002112..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_read.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/READ.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_skcm.py b/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_skcm.py
deleted file mode 100644
index a0ecb95bcc594f03056e1796fe81eef49a14ba0e..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_skcm.py
+++ /dev/null
@@ -1,176 +0,0 @@
-
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = True, # Set to True for image-only training
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/stage2_swinLongNet_resampler.py b/code/xtuner/configs/slidechat/stage2_swinLongNet_resampler.py
deleted file mode 100644
index 92693ca7214e2a4813587475243c0b3ed9b66f9c..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage2_swinLongNet_resampler.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 51200
-per_image_length = 31240
-sample_type='wsi' # 'wsi'or'image'
-
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 16 # global batch size 8 * 16 = 128
-dataloader_num_workers = 32
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-5 # 把 base lr 设成视觉侧(LongNet/Swin)的 lr
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1.0 # grad clip bump to 1.2
-warmup_ratio = 0.09
-
-# Save
-save_steps = 512
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 128
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/h5_files/TCGA-A7-A0CJ-01Z-00-DX2.h5'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=False,
- train_stage='2', # freeze the llm and longnet
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_resampler_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_resampler_hf/projector/projector.safetensors',
- perceiver_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_resampler_hf/perceiver/perceiver.safetensors',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- # quantization_config=dict(
- # type=BitsAndBytesConfig,
- # load_in_4bit=True,
- # load_in_8bit=False,
- # llm_int8_threshold=6.0,
- # llm_int8_has_fp16_weight=False,
- # bnb_4bit_compute_dtype=torch.bfloat16,
- # bnb_4bit_use_double_quant=True,
- # bnb_4bit_quant_type='nf4'),
- ),
- use_swin_longnet = True,
- use_perceiver_resampler = True,
- perceiver_num_latents=256,
- perceiver_depth=4,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=30000,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/missing_slides.csv'
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg=dict(
- # LN / bias 不做 weight decay
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
- # custom_keys={
- # 'projector': dict(lr_mult=0.15) # reduce the learning rate of projector to 3e-5
- # }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
-
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-# param_scheduler = [
-# dict(
-# type=LinearLR,
-# start_factor=0.01, # 从 1% 的 lr 慢启动
-# by_epoch=True,
-# begin=0,
-# end=warmup_ratio * max_epochs,
-# convert_to_iter_based=True # 按 iter 计算
-# ),
-# dict(
-# type=CosineAnnealingLR,
-# eta_min=0.0,
-# by_epoch=True,
-# begin=warmup_ratio * max_epochs,
-# end=max_epochs,
-# convert_to_iter_based=True
-# )
-# ]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[
- dict(type=WandbVisBackend, init_kwargs=dict(project='stage2_swin_longnet_resampler_projector_slidechat1'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|ntok|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_1.py b/code/xtuner/configs/slidechat/stage_1.py
deleted file mode 100644
index 592cfb8e6e70e20c63d378a4a5c75d61f0285815..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage1_caption.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 51200
-per_image_length = 10240
-sample_type='wsi' # 'wsi'or'image'
-
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 256
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 5e-3
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Save
-save_steps = 1000
-save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4'),
- )
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-# learning policy
-# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
-param_scheduler = [
- # dict(
- # type=LinearLR,
- # start_factor=1e-5,
- # by_epoch=True,
- # begin=0,
- # end=warmup_ratio * max_epochs,
- # convert_to_iter_based=True),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-# set visualizer
-visualizer = None
-
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/stage_1_no_longnet_ori.py b/code/xtuner/configs/slidechat/stage_1_no_longnet_ori.py
deleted file mode 100644
index b6a8a293ad7db34d61b57d538a5ed5561cf22567..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_no_longnet_ori.py
+++ /dev/null
@@ -1,255 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model.llava_no_longnet import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-# from xtuner.model.torchscale.model.create_longnet_for_training import create_longvit_model_fast as create_longnet_vit
-# from xtuner.model.torchscale.model.LongNetVit import gigapath_slide_enc3l1536d
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/stage1_morph.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 15836
-#max_length = 32768
-per_image_length = 10240
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (epoch-based)
-batch_size = 1
-accumulative_counts = 256 # 6 * 44 = 264 or 8 * 32 = 256
-dataloader_num_workers = 10
-seed = 2025
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.0 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 epoch 为主
-max_epochs = 3
-warmup_ratio = 0.05 # 预热占比(相对 max_iters)
-
-# Save
-save_steps = 4096
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 512
-SYSTEM = ''
-evaluation_images = '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5'
-evaluation_inputs = ['Are the tumor cells organized in a lobulated pattern within the slide?']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
-
- max_position_embeddings = None, # original 32000 +
- #only use
- enable_token_merge = False,
- # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
- use_perceiver_resampler=False,
-
- concat_text_to_queries = True,
- perceiver_num_latents=per_image_length,
- perceiver_depth=4
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- # reduce length to 10240
- sample_num=per_image_length,
- image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides.csv',
- sample_strategy='linspace', #use linspace
-)
-
-
-# cying: add: per_image_length=per_image_length
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=True,
- prefetch_factor=4,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg = dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=0.01, # 从 1% 的 lr 慢启动
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True # 按 iter 计算
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[
- dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_swin_longnet_slidechat4'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_1_no_longnet_ori_new.py b/code/xtuner/configs/slidechat/stage_1_no_longnet_ori_new.py
deleted file mode 100644
index e0d5ee279f4e6232fcd2dabaa0a0781aeada1f71..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_no_longnet_ori_new.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model.llava_only_projector import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-# from xtuner.model.torchscale.model.create_longnet_for_training import create_longvit_model_fast as create_longnet_vit
-# from xtuner.model.torchscale.model.LongNetVit import gigapath_slide_enc3l1536d
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/stage1_morph2.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-# max_length = 15836
-# per_image_length = 10240
-max_length = 30720
-per_image_length = 26624
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (epoch-based)
-batch_size = 1
-accumulative_counts = 256 # 5 * 410 = 2000
-dataloader_num_workers = 10
-seed = 2025
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.0 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 epoch 为主
-max_epochs = 2
-warmup_ratio = 0.05 # 预热占比(相对 max_iters)
-
-# Save
-save_steps = 4096
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 512
-SYSTEM = ''
-evaluation_images = '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5'
-evaluation_inputs = ['Are the tumor cells organized in a lobulated pattern within the slide?']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
-
- max_position_embeddings = None, # original 32000 +
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- # reduce length to 10240
- sample_num=per_image_length,
- image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides.csv',
- sample_strategy='linspace', #use linspace
-)
-
-
-# cying: add: per_image_length=per_image_length
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=True,
- prefetch_factor=4,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg = dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
-
- paramwise_cfg=dict(
- custom_keys={'^projector\\.': dict(lr_mult=1.0)},
- # 关键:明确只收集 projector,其他丢弃
- # 有些实现没有这个开关;那就用 EnsureProjectorInOptimHook 热修
- )
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=0.01, # 从 1% 的 lr 慢启动
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True # 按 iter 计算
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-# param_scheduler = [
-
-# dict(
-# type = ConstantLR,
-# by_epoch = True,
-# begin = 0,
-# end = max_epochs,
-# convert_to_iter_based=True
-# )
-# ]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-# visualizer = dict(
-# type=Visualizer,
-# vis_backends=[
-# dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_only_projector1'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_1_resampler.py b/code/xtuner/configs/slidechat/stage_1_resampler.py
deleted file mode 100644
index d610a17e07b2ecff0ff157fe65fd826222fcb758..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_resampler.py
+++ /dev/null
@@ -1,252 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHookResampler, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model.llava_no_longnet_simple_sampler import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/stage1_morph2.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 15836
-per_image_length = 10240
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (epoch-based)
-batch_size = 1
-accumulative_counts = 256 # 8 * 256 = 2048
-dataloader_num_workers = 10
-seed = 42
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.0 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 epoch 为主
-max_epochs = 2
-warmup_ratio = 0.08 # 预热占比(相对 max_iters)
-
-# Save
-save_steps = 5120
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 512
-SYSTEM = ''
-evaluation_images = '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5'
-evaluation_inputs = ['Are the tumor cells organized in a lobulated pattern within the slide?']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
-
- max_position_embeddings = None, # original 32000 +
- enable_token_merge = True,
- # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
- use_resampler=True,
- resampler_num_latents=100,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=per_image_length,
- image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides3.csv',
- sample_strategy='linspace', #use linspace
-)
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=True,
- prefetch_factor=4,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg = dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
- paramwise_cfg=dict(
- custom_keys={'^projector\\.': dict(lr_mult=1.0)},
- # 关键:明确只收集 projector,其他丢弃
- # 有些实现没有这个开关;那就用 EnsureProjectorInOptimHook 热修
- ),
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=0.01, # 从 1% 的 lr 慢启动
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True # 按 iter 计算
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHookResampler,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-# visualizer = dict(
-# type=Visualizer,
-# vis_backends=[
-# dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_no_longnet_simple_resampler_projector100'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=seed, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_1_swinLongNet.py b/code/xtuner/configs/slidechat/stage_1_swinLongNet.py
deleted file mode 100644
index dd45339fe1627ef724f6ae59bbbee36383375d63..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_swinLongNet.py
+++ /dev/null
@@ -1,256 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage1_caption.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 51200
-per_image_length = 31240
-sample_type='wsi' # 'wsi'or'image'
-
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 32 # global batch size 8 * 32 = 256
-dataloader_num_workers = 32
-max_epochs = 4
-optim_type = AdamW
-lr = 1e-3 # 把 base lr 设成视觉侧(LongNet/Swin)的 lr
-betas = (0.9, 0.999)
-weight_decay = 0.05
-max_norm = 1.0 # grad clip bump to 1.2
-warmup_ratio = 0.09
-
-# Save
-save_steps = 512
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 128
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/h5_files/TCGA-A7-A0CJ-01Z-00-DX2.h5'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-bnb = dict(type=BitsAndBytesConfig,
- load_in_4bit=True, # or load_in_8bit=True
- bnb_4bit_compute_dtype=torch.bfloat16, # or torch.float16
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- # long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/LongNet_encoder/longnet_encoder.safetensors',
- # projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/projector/projector.safetensors',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- # device_map = 'cuda',
- quantization_config=None
- ),
- use_swin_longnet = True,
- use_perceiver_resampler = False,
- perceiver_num_latents=64,
- perceiver_depth=2,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=31240,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/missing_slides.csv'
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg=dict(
- # LN / bias 不做 weight decay
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
- custom_keys={
- # projector 需要更大学习率(~1e-3):3e-5 * 40 = 1.2e-3
- # 视觉主干(含 patch merging / Swin blocks):用 base lr
- 'LongNet_encoder'
- : dict(lr_mult=0.2),
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-# param_scheduler = [
-
-# dict(
-# type=CosineAnnealingLR,
-# eta_min=0.0,
-# by_epoch=True,
-# begin=0,
-# end=max_epochs,
-# convert_to_iter_based=True)
-# ]
-
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=0.01, # 从 1% 的 lr 慢启动
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True # 按 iter 计算
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[
- dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_swin_longnet_slidechat4'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|ntok|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_1_swinLongNet2.py b/code/xtuner/configs/slidechat/stage_1_swinLongNet2.py
deleted file mode 100644
index 6a311bb25d7be3b811468d8632a1c1f6fa5f4baa..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_swinLongNet2.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler, InfiniteSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook, EMAHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from xtuner.engine.runner import TrainLoop
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
- ThroughputHook
- )
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-llm_name_or_path = 'Qwen/Qwen3-8B'
-
-# —— 把 merge 脚本生成的文件填进来 ——
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/merged_dataset/stage1_morph.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 21200
-per_image_length = 10240
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (iter-based)
-batch_size = 1
-accumulative_counts = 160 # 7 * 160 = 1120
-dataloader_num_workers = 32
-seed = 2025
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.01 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 epoch 为主
-max_epochs = 3
-warmup_ratio = 0.20 # 预热占比(相对 max_iters)
-
-# Save / Eval
-evaluation_freq = 512
-save_steps = 1024
-save_total_limit = 40
-
-# Eval inputs
-SYSTEM = ''
-evaluation_images = ['/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/h5_files/TCGA-A7-A0CJ-01Z-00-DX2.h5',
- '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/skcm_224x224_b20_t15/h5_files/TCGA-EE-A3AG-01Z-00-DX1.h5'
- ]
-evaluation_inputs = [
- 'Can you provide an overall description of the whole slide image (WSI)?',
- 'Which features are most pivotal for diagnosis in this slide?'
-]
-
-# The whole slide image shows a marked disruption of normal tissue architecture due to infiltrating carcinoma cells.
-# There is a complete disorganization, with the loss of normal glandular or tubular structures,
-# as evidenced by very poor tubular formation.
-# The carcinoma cells exhibit significant nuclear pleomorphism,
-# presenting a wide variation in nuclear size and shape. Mitotic figures are moderately present throughout the slide.
-# Additionally, calcifications can be seen within the tissue.
-# The areas displaying ductal carcinoma in situ (DCIS) exhibit both solid and cribriform patterns, with varying nuclear grades and focal central necrosis.
-# There is a notable extensive intraductal component present.
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
- use_swin_longnet=True,
- longnet_pe_gate_ratio = 0.25,
- # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
- use_perceiver_resampler=False,
- perceiver_num_latents=256,
- perceiver_depth=4
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=per_image_length, # 与 per_image_length 保持一致
- image_feature_prefix='/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/data/qingq/PathVLM/baselines/github/SlideChat/merged_dataset/missing_slides.csv',
- sample_strategy='linspace', #use linspace
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=True,
- prefetch_factor=4,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=AdamW,
- lr=lr, betas=betas, weight_decay=weight_decay,
- ),
- paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-param_scheduler = [
- dict( # 线性预热(按 iter)
- type=LinearLR,
- start_factor=0.001,
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True
- ),
- dict( # 余弦退火(按 iter)
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-# param_scheduler = [
-# dict( # 余弦退火(按 iter)
-# type=CosineAnnealingLR,
-# eta_min=0.0,
-# by_epoch=True,
-# begin=0,
-# end=max_epochs,
-# convert_to_iter_based=True
-# )
-# ]
-
-# 迭代制训练循环
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- # 可选:EMA 平滑
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=save_steps, max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_swin_longnet_qwen3_filtered_wsi_llava'))]
-)
-# visualizer = None
-
-log_level = 'INFO'
-load_from = "/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/stage1_swin_longnet_qwen3_wsi_llava3_/iter_8192_model_only"
-resume = False
-
-randomness = dict(seed=seed, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|ntok|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler.py b/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler.py
deleted file mode 100644
index 16779799c961b5803dc53068aafec9366f5f34f6..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler.py
+++ /dev/null
@@ -1,254 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage1_caption.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 51200
-per_image_length = 31240
-sample_type='wsi' # 'wsi'or'image'
-
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 32 # global batch size 8 * 16 = 128
-dataloader_num_workers = 32
-max_epochs = 3
-optim_type = AdamW
-lr = 2e-4 # 把 base lr 设成视觉侧(LongNet/Swin)的 lr
-betas = (0.9, 0.999)
-weight_decay = 0.0
-max_norm = 1.0 # grad clip bump to 1.2
-warmup_ratio = 0.09
-
-# Save
-save_steps = 512
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 128
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(type=BitsAndBytesConfig,
- load_in_4bit=True, # or load_in_8bit=True
- bnb_4bit_compute_dtype=torch.bfloat16, # or torch.float16
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='0', # freeze the llm and longnet
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_30k_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_30k_hf/projector/projector.safetensors',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- device_map="auto",
- torch_dtype=torch.bfloat16,
- quantization_config=bnb),
- use_swin_longnet = True,
- use_perceiver_resampler = True,
- perceiver_num_latents=256,
- perceiver_depth=4,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=30000,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/missing_slides.csv'
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- # paramwise_cfg=dict(
- # # LN / bias 不做 weight decay
- # norm_decay_mult=0.0,
- # bias_decay_mult=0.0,
- # custom_keys={
- # 'projector': dict(lr_mult=0.15), # reduce the learning rate of projector to 3e-5
- # 'LongNet_encoder': dict(lr_mult=0.15)
- # }
- # ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-# param_scheduler = [
-
-# dict(
-# type=CosineAnnealingLR,
-# eta_min=0.0,
-# by_epoch=True,
-# begin=0,
-# end=max_epochs,
-# convert_to_iter_based=True)
-# ]
-
-# param_scheduler = [
-# dict(
-# type=LinearLR,
-# start_factor=0.01, # 从 1% 的 lr 慢启动
-# by_epoch=True,
-# begin=0,
-# end=warmup_ratio * max_epochs,
-# convert_to_iter_based=True # 按 iter 计算
-# ),
-# dict(
-# type=CosineAnnealingLR,
-# eta_min=0.0,
-# by_epoch=True,
-# begin=warmup_ratio * max_epochs,
-# end=max_epochs,
-# convert_to_iter_based=True
-# )
-# ]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[
- dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_frozen_swin_longnet_resampler_projector_slidechat1'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|ntok|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler2.py b/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler2.py
deleted file mode 100644
index 14c7c5f9c04eae3a3aad6c54c31740a82b4d4339..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler2.py
+++ /dev/null
@@ -1,242 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler, InfiniteSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook, EMAHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from mmengine.runner import IterBasedTrainLoop
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
- ThroughputHook
- )
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-llm_name_or_path = 'Qwen/Qwen3-8B'
-
-# —— 把 merge 脚本生成的文件填进来 ——
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_stage1_wsi-llava_mix_.json'
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/merged_dataset/stage1_morph.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 21200
-per_image_length = 10240
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (iter-based)
-batch_size = 1
-accumulative_counts = 32 # 8 * 32 = 256
-dataloader_num_workers = 32
-seed = 2025
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.01 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 iteration 为主
-max_iters = 25000
-warmup_ratio = 0.25 # 预热占比(相对 max_iters)
-warmup_iters = 2500 # = 7500
-
-# Save / Eval
-evaluation_freq = 512
-save_steps = 1024
-save_total_limit = 10
-
-# Eval inputs
-SYSTEM = ''
-evaluation_images = ['/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/h5_files/TCGA-A7-A0CJ-01Z-00-DX2.h5',
- '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/skcm_224x224_b20_t15/h5_files/TCGA-EE-A3AG-01Z-00-DX1.h5'
- ]
-evaluation_inputs = [
- 'Craft a comprehensive outline capturing the key findings of the pathology report based on the whole slide image.',
- 'Craft a comprehensive outline capturing the key findings of the pathology report based on the whole slide image.'
-]
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["lm_head"],
- llm_int8_skip_modules=["lm_head"],
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='0',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
- use_swin_longnet=True,
- # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
- use_perceiver_resampler=True,
- perceiver_num_latents=256,
- perceiver_depth=4,
- long_net_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_skcm_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth = "/data/qingq/PathVLM/baselines/github/SlideChat/models/stage1_swin_longnet_skcm_hf/projector/projector.safetensors"
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=per_image_length, # 与 per_image_length 保持一致
- image_feature_prefix='/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/data/qingq/PathVLM/baselines/github/SlideChat/dataset/missing_slides.csv',
- sample_strategy='random',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=False,
- dataset=llava_dataset,
- sampler=dict(type=InfiniteSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg=dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
- custom_keys={
- # 'LongNet_encoder': dict(lr_mult=0.3), # LongNet 稍慢
- 'projector': dict(lr_mult=0.2),
- # 'perceiver.query_pos': dict(lr_mult=5)
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- # dict( # 线性预热(按 iter)
- # type=LinearLR,
- # start_factor=0.001,
- # by_epoch=False,
- # begin=0,
- # end=warmup_iters
- # ),
- dict( # 余弦退火(按 iter)
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=False,
- begin=0,
- end=max_iters
- )
-]
-
-# 迭代制训练循环
-train_cfg = dict(type=IterBasedTrainLoop, max_iters=max_iters)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- # 可选:EMA 平滑
- # dict(type=EMAHook,
- # ema_type='ExponentialMovingAverage',
- # momentum=0.0002,
- # update_buffers=True,
- # strict_load=False,
- # # 在预热结束后再启用 EMA
- # begin_iter=warmup_iters
- # ),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=save_steps, max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_swin_longnet_perceiver_qwen3_wsi_llava'))]
-)
-visualizer = None
-
-log_level = 'INFO'
-load_from = None
-resume = False
-
-randomness = dict(seed=seed, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|ntok|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler_long.py b/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler_long.py
deleted file mode 100644
index 696ee085705cad8f555ffa2fa8fda8ec97a0b0bd..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_swinLongNet_resampler_long.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler, InfiniteSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook, EMAHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from mmengine.runner import IterBasedTrainLoop
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
- ThroughputHook
- )
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-llm_name_or_path = 'Qwen/Qwen3-8B'
-
-# —— 把 merge 脚本生成的文件填进来 ——
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_stage1_wsi-llava_mix_.json'
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_stage1_wsi-llava_mix_skcm.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 32000
-per_image_length = 20480
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (iter-based)
-batch_size = 1
-accumulative_counts = 32 # 8 * 32 = 256
-dataloader_num_workers = 32
-seed = 2025
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.01 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 iteration 为主
-max_iters = 15000
-warmup_ratio = 0.25 # 预热占比(相对 max_iters)
-warmup_iters = 2500 # = 7500
-
-# Save / Eval
-evaluation_freq = 512
-save_steps = 1024
-save_total_limit = 10
-
-# Eval inputs
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/h5_files/TCGA-A7-A0CJ-01Z-00-DX2.h5'
-evaluation_inputs = [
- 'Craft a comprehensive outline capturing the key findings of the pathology report based on the whole slide image.'
-]
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["lm_head"],
- llm_int8_skip_modules=["lm_head"],
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
- use_swin_longnet=True,
- # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
- use_perceiver_resampler=True,
- perceiver_num_latents=256,
- perceiver_depth=4,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=per_image_length, # 与 per_image_length 保持一致
- image_feature_prefix='/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/data/qingq/PathVLM/baselines/github/SlideChat/dataset/missing_slides.csv',
- sample_strategy='random',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=False,
- dataset=llava_dataset,
- sampler=dict(type=InfiniteSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg=dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
- custom_keys={
- 'LongNet_encoder': dict(lr_mult=0.2), # LongNet 稍慢
- # 'perceiver': dict(lr_mult = 0.2) # perceiver 慢
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- dict( # 线性预热(按 iter)
- type=LinearLR,
- start_factor=0.001,
- by_epoch=False,
- begin=0,
- end=warmup_iters
- ),
- dict( # 余弦退火(按 iter)
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=False,
- begin=warmup_iters,
- end=max_iters
- )
-]
-
-# 迭代制训练循环
-train_cfg = dict(type=IterBasedTrainLoop, max_iters=max_iters)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- # 可选:EMA 平滑
- # dict(type=EMAHook,
- # ema_type='ExponentialMovingAverage',
- # momentum=0.0002,
- # update_buffers=True,
- # strict_load=False,
- # # 在预热结束后再启用 EMA
- # begin_iter=warmup_iters
- # ),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(type=CheckpointHook, by_epoch=False, interval=save_steps, max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_swin_longnet_perceiver_qwen3_wsi_llava_skcm_long'))]
-)
-# visualizer = None
-
-log_level = 'INFO'
-load_from = None
-resume = False
-
-randomness = dict(seed=seed, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|ntok|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_1_token_merge.py b/code/xtuner/configs/slidechat/stage_1_token_merge.py
deleted file mode 100644
index d8414eb83776cc3aeee102d5ade81936979971ea..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_1_token_merge.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model.llava_no_longnet import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-# from xtuner.model.torchscale.model.create_longnet_for_training import create_longvit_model_fast as create_longnet_vit
-# from xtuner.model.torchscale.model.LongNetVit import gigapath_slide_enc3l1536d
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-data_path = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/stage1_morph2.json'
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 15836
-per_image_length = 10240
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (epoch-based)
-batch_size = 1
-accumulative_counts = 400 # 5 * 400 = 2000
-dataloader_num_workers = 5
-seed = 2025
-optim_type = AdamW
-lr = 1e-3
-betas = (0.9, 0.999)
-weight_decay = 0.0 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 epoch 为主
-max_epochs = 2
-warmup_ratio = 0.05 # 预热占比(相对 max_iters)
-
-# Save
-save_steps = 5120
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 512
-SYSTEM = ''
-evaluation_images = '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5'
-evaluation_inputs = ['Are the tumor cells organized in a lobulated pattern within the slide?']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='1',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
-
- max_position_embeddings = None, # original 32000 +
- enable_token_merge = True,
- # 建议:前期用 resampler 更稳(也更省显存);如不需要可改回 False
- use_perceiver_resampler=False,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=per_image_length,
- image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides3.csv',
- sample_strategy='linspace', #use linspace
-)
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=True,
- prefetch_factor=4,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg = dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
-
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=0.01, # 从 1% 的 lr 慢启动
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True # 按 iter 计算
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-visualizer = dict(
- type=Visualizer,
- vis_backends=[
- dict(type=WandbVisBackend, init_kwargs=dict(project='stage1_no_longnet_with_language_guide2'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_2.py b/code/xtuner/configs/slidechat/stage_2.py
deleted file mode 100644
index 2a90695daf4103d3ec96f60e0439f04bc5a8980a..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json' #'slidechat_train_vqa_stage2.json'
-
-## specific tumor
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth' #'stage1_pth'
-# pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_orignal_1024maxlength_freezelongnet/iter_4066.pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 19600
-per_image_length = 196 #None
-sample_type='wsi' # 'wsi'or'image'
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 1
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Save
-save_steps = 500
-save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv' #'./BLCA/TCGA-GV-A40G-01Z-00-DX1.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True, #False,
- pretrained_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth', #pretrained_pth,
- # long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- # projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/pytorch_model.bin',
- train_stage='2',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- # attn_implementation=None, # 👈 disable FlashAttention
- # device_map='auto',
- # quantization_config=dict( # 量化配置(保留则为 4 比特,删除则为正常浮点)
- # type=BitsAndBytesConfig,
- # load_in_4bit=True,
- # load_in_8bit=False,
- # llm_int8_threshold=6.0,
- # llm_int8_has_fp16_weight=False,
- # bnb_4bit_compute_dtype=torch.float16,
- # bnb_4bit_use_double_quant=True,
- # bnb_4bit_quant_type='nf4'),
- # ),
-
- # llm_lora=dict( # LoRA 配置(保留则使用 LoRA 微调,删除则使用全量微调)
- # type=LoraConfig,
- # r=64,
- # lora_alpha=16,
- # lora_dropout=0.1,
- # bias='none',
- # task_type='CAUSAL_LM')
- )
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=2000,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-# learning policy
-# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
-param_scheduler = [
- # dict(
- # type=LinearLR,
- # start_factor=1e-5,
- # by_epoch=True,
- # begin=0,
- # end=warmup_ratio * max_epochs,
- # convert_to_iter_based=True),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-# set visualizer
-visualizer = None
-
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_divprune.py b/code/xtuner/configs/slidechat/stage_2_divprune.py
deleted file mode 100644
index ad288c2e7e9f9689598bc225dd759d582f3447b1..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_divprune.py
+++ /dev/null
@@ -1,237 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_DivPrune
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json' #'slidechat_train_vqa_stage2.json'
-
-## specific tumor
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth' #'stage1_pth'
-# pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_orignal_1024maxlength_freezelongnet/iter_4066.pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 19600
-per_image_length = 196 #None
-sample_type='wsi' # 'wsi'or'image'
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Save
-save_steps = 500
-save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv' #'./BLCA/TCGA-GV-A40G-01Z-00-DX1.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel_DivPrune,
- freeze_llm=True, #False,
- pretrained_pth=pretrained_pth,
- train_stage='2',
- divprune_ratio = 0.50,
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
-
- # attn_implementation=None, # 👈 disable FlashAttention
-
- # quantization_config=dict( # 量化配置(保留则为 4 比特,删除则为正常浮点)
- # type=BitsAndBytesConfig,
- # load_in_4bit=True,
- # load_in_8bit=False,
- # llm_int8_threshold=6.0,
- # llm_int8_has_fp16_weight=False,
- # bnb_4bit_compute_dtype=torch.float16,
- # bnb_4bit_use_double_quant=True,
- # bnb_4bit_quant_type='nf4'),
-
- # lora=dict( # LoRA 配置(保留则使用 LoRA 微调,删除则使用全量微调)
- # type=LoraConfig,
- # r=64,
- # lora_alpha=16,
- # lora_dropout=0.1,
- # bias='none',
- # task_type='CAUSAL_LM')
- )
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-# learning policy
-# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
-param_scheduler = [
- # dict(
- # type=LinearLR,
- # start_factor=1e-5,
- # by_epoch=True,
- # begin=0,
- # end=warmup_ratio * max_epochs,
- # convert_to_iter_based=True),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-# set visualizer
-visualizer = None
-
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(by_epoch=False)
diff --git a/code/xtuner/configs/slidechat/stage_2_dynamic_qlora.py b/code/xtuner/configs/slidechat/stage_2_dynamic_qlora.py
deleted file mode 100644
index 8a0fc23dc6535f45294deb30e83457d8204ce85f..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_dynamic_qlora.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel, DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-
-# Data
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json' #'slidechat_train_vqa_stage2.json'
-
-## specific tumor
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json'
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth' #'stage1_pth'
-# pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_orignal_1024maxlength_freezelongnet/iter_4066.pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 19600
-per_image_length = 10240 #None
-sample_type='wsi' # 'wsi'or'image'
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Save
-save_steps = 100
-save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv' #'./BLCA/TCGA-GV-A40G-01Z-00-DX1.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-# from transformers import BitsAndBytesConfig
-# use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
-# half_dtype = torch.bfloat16 if use_bf16 else torch.float16
-
-model = dict(
- type=DynamicLLaVAQwen25,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf/projector/projector.safetensors',
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type = DynamicQwen2ForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='flash_attention_2', # 'flash_attention_2', # 👈 disable FlashAttention
- torch_dtype= torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-
- # NEW: keep predictors & lm_head in fp16
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
- projector_depth=2,
- enable_long_net=True,
- hidden_size=512,
- image_feature_length=196,
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- # target_modules="all-linear",
- bias="none",
- task_type="CAUSAL_LM",
- ),
- sparse_config=dict(
- use_vision_predictor=True,
- use_text_predictor=False,
- use_output_text_predictor=False,
- use_instruct_predictor=False,
- sparse_layer=15,
- d_model=512,
- nhead=8,
- vision_keep_rate=0.5,
- dim_feedforward=768,
- num_layers=2,
- output_text_keep_rate=0.8,
- instruct_keep_rate=0.9,
- mask_loss_weight=0.1,
- output_text_len_for_training=100,
- instruct_len_for_training=100,
- ),
- divprune_ratio=0.2, # NEW: divprune ratio
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
-
- # 'image_score_predictor': dict(lr_mult=5.0),
- # 'output_text_score_predictor': dict(lr_mult=5.0),
- # 'instruct_score_predictor': dict(lr_mult=5.0),
-
- # # (optional) if your build sometimes names them differently:
- # 'visionpredictor': dict(lr_mult=5.0),
- # 'textpredictor': dict(lr_mult=5.0),
- # 'instructpredictor': dict(lr_mult=5.0),
- }),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-# learning policy
-# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
-param_scheduler = [
- # dict(
- # type=LinearLR,
- # start_factor=1e-5,
- # by_epoch=True,
- # begin=0,
- # end=warmup_ratio * max_epochs,
- # convert_to_iter_based=True,
- # ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=PrintInitLRHook, show_group_summary=True),
-
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- # dict(
- # type=EvaluateChatHook,
- # tokenizer=tokenizer,
- # every_n_iters=evaluation_freq,
- # evaluation_inputs=evaluation_inputs,
- # evaluation_images=evaluation_images,
- # system=SYSTEM,
- # prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-# set visualizer
-visualizer = None
-
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
-# log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_2_fastv.py b/code/xtuner/configs/slidechat/stage_2_fastv.py
deleted file mode 100644
index bf582c63074b7d2ef5ace5026c8a8b877fce4eef..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_fastv.py
+++ /dev/null
@@ -1,161 +0,0 @@
-#!/usr/bin/env python3
-"""
-Stage-2 training config using Qwen2Model (from fastv_qwen.py) directly
-and enabling FastV token pruning. The LLM dict is respected and
-instantiated inside Qwen2Model.__init__.
-"""
-
-import torch
-from torch.optim import AdamW
-from transformers import AutoTokenizer, BitsAndBytesConfig
-
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, ThroughputHook, PrintInitLRHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-
-# Import our custom model (aliases provided in fastv_qwen.py)
-from xtuner.model import Qwen25ModelFastV, Qwen25ForCausalLMFastV
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-llm_name_or_path = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_hf'
-fastv_k = 2
-fastv_r = 0.5
-use_fastv = True
-per_image_length = 10240
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-data_path = (
- '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/'
- 'stage2_vqa_tumor_/SKCM.json'
-)
-
-max_length = 19600
-sample_type = 'wsi'
-
-# Train hyperparams
-batch_size = 1
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 2
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-warmup_ratio = 0.03
-
-# Saving / Eval
-save_steps = 100
-save_total_limit = 3
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = [
- 'Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.'
-]
-
-#######################################################################
-# PART 2 Tokenizer #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right',
-)
-
-#######################################################################
-# PART 3 Model — Qwen2Model + LLM wrapper #
-#######################################################################
-# Pass a plain dict for config; Qwen2Model.__init__ will normalize it
-config_dict = dict(
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- # Attention impl so we can request attentions for FastV
- attn_implementation='eager',
-
- # FastV knobs read by Qwen2Model.reset_fastv()
- use_fast_v=bool(use_fastv),
- fast_v_agg_layer=int(fastv_k),
- fast_v_sys_length=0,
- fast_v_image_token_length=int(per_image_length),
- fast_v_attention_rank=max(1, int(round(per_image_length * float(fastv_r)))),
- fast_v_inplace=True,
-)
-
-model = dict(
- type=Qwen25ModelFastV, # class object
- config=config_dict, # plain dict (no runtime calls here)
-
- # ✅ LLM dict WILL be instantiated by Qwen2Model.__init__
- llm=dict(
- type=Qwen25ForCausalLMFastV.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- attn_implementation='eager',
- torch_dtype=torch.bfloat16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- llm_int8_skip_modules=["image_score_predictor", "output_text_score_predictor", "instruct_score_predictor", "lm_head"],
- ),
- ),
-
- freeze_llm=True,
- pretrained_pth=None,
- projector_depth=2,
- # llm_lora=dict(
- # type=LoraConfig,
- # r=512,
- # lora_alpha=256,
- # lora_dropout=0.05,
- # bias='none',
- # task_type='CAUSAL_LM',
- # ),
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- enable_long_net=True,
- long_net_pth=(
- '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/'
- 'stage2_hf/LongNet_encoder/longnet_encoder.safetensors'
- ),
- projector_pth=(
- '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/'
- 'stage2_hf/projector/projector.safetensors'
- ),
- image_feature_length=196,
-)
-
-#######################################################################
-# PART 4 Quantization (optional top-level) #
-#######################################################################
-quantization_config = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4',
- skip_modules=['lm_head'],
- llm_int8_skip_modules=['lm_head'],
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_fusion_compressor.py b/code/xtuner/configs/slidechat/stage_2_fusion_compressor.py
deleted file mode 100644
index 85eeaae9d64e75622fdbd80c09dfee2d969f18e3..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_fusion_compressor.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 1
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=2000,
- prefusion_layer_num = 2,
- image_only = False,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_fusion_compressor_500.py b/code/xtuner/configs/slidechat/stage_2_fusion_compressor_500.py
deleted file mode 100644
index b393f31deb1c8bda0b708e6e845ef4e7a8b33fb5..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_fusion_compressor_500.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel_FusionCompressor
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 1
-dataloader_num_workers = 64
-max_epochs = 1
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 1000
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModel_FusionCompressor,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- compressor_grid_size=100,
- prefusion_layer_num = 2,
- image_only = False,
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_pure.py b/code/xtuner/configs/slidechat/stage_2_pure.py
deleted file mode 100644
index 8177936b12524cbc84ed4fb53f93989509540de8..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_pure.py
+++ /dev/null
@@ -1,280 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.visualization import Visualizer, WandbVisBackend
-from mmengine.optim import AmpOptimWrapper, ConstantLR, LinearLR, CosineAnnealingLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model.llava_only_projector import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-# from xtuner.model.torchscale.model.create_longnet_for_training import create_longvit_model_fast as create_longnet_vit
-# from xtuner.model.torchscale.model.LongNetVit import gigapath_slide_enc3l1536d
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-# Data
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/merged_dataset/stage2_tasks_plus_report2.json'
-data_path = "/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/dataset/all_data.json"
-image_path_list = None
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-# 长序列:保持 per_image_length == sample_num
-max_length = 26076
-per_image_length = 21980
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer (epoch-based)
-batch_size = 1
-accumulative_counts = 256 # 8 * 256 = 512
-dataloader_num_workers = 10
-seed = 42
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0.01 # 适度WD抑制漂移
-max_norm = 1 # 更紧的梯度裁剪
-
-# 以 epoch 为主
-max_epochs = 2
-warmup_ratio = 0.05 # 预热占比(相对 max_iters)
-
-# Save
-save_steps = 4096
-save_total_limit = 8 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 512
-SYSTEM = ''
-evaluation_images = ['/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EB-A5UN-06Z-00-DX1.h5',
- '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/skcm_224x224_b20_t15/h5_files/TCGA-EE-A3AG-01Z-00-DX1.h5',
- '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/lusc_224x224_b20_t15/h5_files/TCGA-NC-A5HP-01Z-00-DX1.h5',
- '/mnt/bn/xudong-va/meilong/datasets/Token_Compression/ucec_224x224_b20_t15/h5_files/TCGA-AJ-A3TW-01Z-00-DX1.h5'
- ]
-evaluation_inputs = [
- 'Are the tumor cells organized in a lobulated pattern within the slide?',
- 'Craft a comprehensive outline capturing the key findings of the pathology report based on the whole slide image.',
- 'Based on the observed features, what do you think is the correct histological classification of the tumor? A) Poorly differentiated keratinizing squamous cell carcinoma B) Moderately differentiated squamous cell carcinoma C) Well-differentiated squamous cell carcinoma D) Adenocarcinoma',
- 'Based on the IHC results showing diffuse positivity for p53, consistent with serous carcinoma, what is the molecular subtype in this uterine carcinoma?'
-]
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-bnb = dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
-)
-
-llm_lora=dict(
- type=LoraConfig,
- r=256,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
-)
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True,
- train_stage='2',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.bfloat16,
- attn_implementation='flash_attention_2',
- quantization_config=bnb
- ),
-
- pretrained_pth = '/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/checkpoints/stage_1/iter_23096.pth',
- # 这里要替换成 stage 1 训练好的权重
- projector_pth=None,
- # projector_pth='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/checkpoints/stage_1/iter_23096.pth',
- # 可以用 pretrained_pth 来加载 stage 1 的整体权重
- max_position_embeddings = None, # original 32000 +
- llm_lora = llm_lora,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- # reduce length to 10240
- sample_num=per_image_length,
- image_feature_prefix='/mnt/bn/xudong-va/meilong/datasets/Token_Compression',
- image_feature_suffix='.h5',
- identifier='_224x224_b20_t15',
- unwanted_prefix_csv='/mnt/bn/yuxuanwang/meilong/code/projects/efficient_foundation_wsi_llava/merged_dataset/missing_slides3.csv',
- sample_strategy='linspace', #use linspace
-)
-
-
-# cying: add: per_image_length=per_image_length
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- persistent_workers=True,
- prefetch_factor=4,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn)
-)
-
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- paramwise_cfg = dict(
- norm_decay_mult=0.0,
- bias_decay_mult=0.0,
-
- paramwise_cfg=dict(
- custom_keys={'^projector\\.': dict(lr_mult=1.0)},
- # 关键:明确只收集 projector,其他丢弃
- # 有些实现没有这个开关;那就用 EnsureProjectorInOptimHook 热修
- )
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=True), # 遇到 NaN 立刻报错
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='bfloat16',
-)
-
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=0.01, # 从 1% 的 lr 慢启动
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True # 按 iter 计算
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True
- )
-]
-
-
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type = ThroughputHook)
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-
-# visualizer = dict(
-# type=Visualizer,
-# vis_backends=[
-# dict(type=WandbVisBackend, init_kwargs=dict(project='stage2_only_projector_all_data'))])
-visualizer = None
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_qlora.py b/code/xtuner/configs/slidechat/stage_2_qlora.py
deleted file mode 100644
index 0c5b2a148d4f1b7a20b1aa3b339d128f8ed7234f..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_qlora.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json' #'slidechat_train_vqa_stage2.json'
-
-## specific tumor
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json'
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth' #'stage1_pth'
-# pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_orignal_1024maxlength_freezelongnet/iter_4066.pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 19600
-per_image_length = 196 #None
-sample_type='wsi' # 'wsi'or'image'
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 1
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Save
-save_steps = 1000
-save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv' #'./BLCA/TCGA-GV-A40G-01Z-00-DX1.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True, #False,
- pretrained_pth=None, #pretrained_pth,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/projector/projector.safetensors',
- train_stage='2',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- # torch_dtype=torch.float16,
- # load_in_4bit=True,
- attn_implementation="sdpa",
- # device_map='auto',
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- ),
- ),
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- )
-)
-
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
-
- }
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-# learning policy
-# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
-param_scheduler = [
- # dict(
- # type=LinearLR,
- # start_factor=1e-5,
- # by_epoch=True,
- # begin=0,
- # end=warmup_ratio * max_epochs,
- # convert_to_iter_based=True,
- # ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-# set visualizer
-visualizer = None
-
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_qlora_d_back.py b/code/xtuner/configs/slidechat/stage_2_qlora_d_back.py
deleted file mode 100644
index 0b74023ae28d0d48469d5da2f48fefe60461cd0d..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_qlora_d_back.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-from xtuner.utils import PROMPT_TEMPLATE
-from xtuner.utils.dynamic_llava_dispatch import DynamicLlavaPatchHook
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json' #'slidechat_train_vqa_stage2.json'
-
-## specific tumor
-# data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideInstruct_train_stage2_vqa.json'
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth' #'stage1_pth'
-# pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_orignal_1024maxlength_freezelongnet/iter_4066.pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-
-max_length = 19600
-per_image_length = 196 #None
-sample_type='wsi' # 'wsi'or'image'
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 16
-dataloader_num_workers = 64
-max_epochs = 1
-optim_type = AdamW
-lr = 2e-4
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Save
-save_steps = 1000
-save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv' #'./BLCA/TCGA-GV-A40G-01Z-00-DX1.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True, #False,
- pretrained_pth=None, #pretrained_pth,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/projector/projector.safetensors',
- train_stage='2',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- # load_in_4bit=True,
- # attn_implementation=None, # 👈 disable FlashAttention
- # device_map='auto',
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- ),
- ),
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- )
-)
-
-# Put your desired ratios here
-_dyn_cfg = dict(
- mode="gather", # "mask" or "gather"
- predictor="dynamic_llava", # <-- use the new predictors
- vision_d_model=512,
- vision_nhead=8,
- vision_dim_feedforward=2048,
- vision_layers=2,
- text_d_model=512,
- text_layers=2,
- target_keep_ratio_image=0.5,
- target_keep_ratio_instruct=0.5,
- target_keep_ratio_answer=0.8,
- gather_image=True,
- gather_text=False,
- min_tokens_per_region=8,
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000,# max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-
-# cying: add: per_image_length=per_image_length,
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# optimizer
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay
- ),
- paramwise_cfg=dict(
- custom_keys={
- 'projector': dict(lr_mult=0.1),
- 'LongNet_encoder': dict(lr_mult=0.1),
- }),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-# learning policy
-# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
-param_scheduler = [
- # dict(
- # type=LinearLR,
- # start_factor=1e-5,
- # by_epoch=True,
- # begin=0,
- # end=warmup_ratio * max_epochs,
- # convert_to_iter_based=True,
- # ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-# Log the dialogue periodically during the training process, optional
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
- dict(type=DynamicLlavaPatchHook, dyn_cfg=_dyn_cfg, layers_every=1),
-]
-
-# configure default hooks
-default_hooks = dict(
- # record the time of every iteration.
- timer=dict(type=IterTimerHook),
- # print log every 10 iterations.
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- # enable the parameter scheduler.
- param_scheduler=dict(type=ParamSchedulerHook),
- # save checkpoint per `save_steps`.
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- # set sampler seed in distributed evrionment.
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-# configure environment
-env_cfg = dict(
- # whether to enable cudnn benchmark
- cudnn_benchmark=False,
- # set multi process parameters
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- # set distributed parameters
- dist_cfg=dict(backend='nccl'),
-)
-
-# set visualizer
-visualizer = None
-
-# set log level
-log_level = 'INFO'
-
-# load from which checkpoint
-load_from = None
-
-# whether to resume training from the loaded checkpoint
-resume = False
-
-# Defaults to use random seed and disable `deterministic`
-randomness = dict(seed=None, deterministic=False)
-
-# set log processor
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
diff --git a/code/xtuner/configs/slidechat/stage_2_qlora_muon.py b/code/xtuner/configs/slidechat/stage_2_qlora_muon.py
deleted file mode 100644
index 8bf1281b8cde577455efac09708a83bff65299d6..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_qlora_muon.py
+++ /dev/null
@@ -1,282 +0,0 @@
-# Modified SlideChat configuration using Muon optimizer
-#
-# This configuration demonstrates how to inject the Muon optimizer into
-# SlideChat's training pipeline. The original configuration used AdamW;
-# we import the Muon optimizer, set it as the optimizer type and define
-# Muon‑specific hyperparameters.
-custom_imports = dict(
- imports=['xtuner.engine.optimizer.muon_wrapper'],
- allow_failed_imports=False,
-)
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook, ThroughputHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModel
-# from xtuner.engine import MuonOptimWrapperConstructor
-from xtuner.utils import PROMPT_TEMPLATE
-from peft import LoraConfig
-
-from muon import MuonWithAuxAdam
-
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-
-max_length = 19600
-per_image_length = 196 # None
-sample_type = 'wsi' # 'wsi' or 'image'
-
-# Scheduler & Optimizer
-batch_size = 1 # per_device
-accumulative_counts = 12
-dataloader_num_workers = 64
-max_epochs = 1
-
-# # Use Muon instead of AdamW.
-# optim_type = MuonWithAuxAdam
-
-# Learning rate for all parameters. When using Muon with parameter
-# groups, you will typically set a higher LR for hidden weights and a
-# lower LR for non‑hidden weights. For simplicity this example keeps a
-# single learning rate; tune as needed.
-lr = 1.5e-4
-
-# Standard Adam betas are not used by Muon; Muon uses momentum,
-# nesterov and ns_steps instead. You can remove betas or leave it
-# unused. Here we leave betas for completeness.
-betas = (0.9, 0.999)
-
-weight_decay = 0
-max_norm = 1 # grad clip
-warmup_ratio = 0.03
-
-# Muon‑specific hyperparameters. According to the Muon README, good
-# defaults are momentum=0.95, nesterov=True and ns_steps=5
-muon_momentum = 0.95
-muon_nesterov = True
-muon_ns_steps = 5
-
-# Save
-save_steps = 1000
-save_total_limit = 3 # Maximum checkpoints to keep (-1 means unlimited)
-
-# Evaluate the generation performance during the training
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-
-model = dict(
- type=LLaVAModel,
- freeze_llm=True, # False,
- pretrained_pth=None, # pretrained_pth,
- long_net_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/LongNet_encoder/longnet_encoder.safetensors',
- projector_pth='/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage1_hf/projector/projector.safetensors',
- train_stage='2',
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- quantization_config=dict(
- type=BitsAndBytesConfig,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- ),
- ),
-
- llm_lora=dict(
- type=LoraConfig,
- r=512,
- lora_alpha=256,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- )
-)
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10000, # max patch number
- image_feature_prefix='/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
-)
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-# Build the Muon optimizer with optional Muon‑specific hyperparameters.
-# When using MuonWithAuxAdam, you can pass momentum, nesterov and ns_steps
-# directly as keyword arguments; these will override the default values
-# inside Muon. If you want to use separate parameter groups (hidden vs
-# non‑hidden weights) you would need to write a custom optim_wrapper
-# constructor; this example uses a single LR for simplicity.
-# Tell MMEngine to use our constructor and pass Muon hparams
-# optim_wrapper = dict(
-# type=AmpOptimWrapper,
-# # this 'optimizer' dict is just a stub; constructor will replace with a real optimizer object
-# optimizer=dict(type=MuonWithAuxAdam),
-# clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
-# accumulative_counts=accumulative_counts,
-# loss_scale='dynamic',
-# dtype='float16',
-# )
-
-# optim_wrapper_constructor = dict(
-# type='MuonOptimWrapperConstructor',
-# muon_hidden_lr=2e-2, # tune if unstable (try 5e-3–1e-2 for QLoRA)
-# muon_weight_decay=1e-2,
-# adamw_lr=lr, # reuse your config's lr for non-hidden
-# adamw_betas=(0.9, 0.95),
-# adamw_weight_decay=1e-2,
-# keep_lora_on_adamw=False, # set True if you want LoRA to stay on AdamW
-# )
-
-
-optim_wrapper = dict(
- type='AmpOptimWrapper', # DS will flip this to DeepSpeedOptimWrapper
- constructor='MuonOptimWrapperConstructor', # <<< IMPORTANT
- # Muon constructor args:
- muon_hidden_lr=0.02,
- muon_weight_decay=0.01,
- adamw_lr=lr, # 1.5e-4 from your cfg
- adamw_betas=(0.9, 0.95),
- adamw_weight_decay=0.01,
- keep_lora_on_adamw=False,
-
- optimizer=dict(
- type=MuonWithAuxAdam # stub; ctor will replace with a real instance
- ),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16',
-)
-
-
-# learning policy
-param_scheduler = [
- dict(
- type=LinearLR,
- start_factor=1e-5,
- by_epoch=True,
- begin=0,
- end=warmup_ratio * max_epochs,
- convert_to_iter_based=True,
- ),
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=warmup_ratio * max_epochs,
- end=max_epochs,
- convert_to_iter_based=True,
- ),
-]
-
-# train, val, test setting
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template),
- dict(type=ThroughputHook),
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-
-log_level = 'INFO'
-
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_2_reducer_1d_conv.py b/code/xtuner/configs/slidechat/stage_2_reducer_1d_conv.py
deleted file mode 100644
index 61b4781d8b6ccc76466fafc784f39667da034eba..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_reducer_1d_conv.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='conv',
- kernel_size=4,
- stride=4
- )
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_2_reducer_attn.py b/code/xtuner/configs/slidechat/stage_2_reducer_attn.py
deleted file mode 100644
index 6b660d76c37e34bd1130497acdbb13ebd2b568be..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_reducer_attn.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/SKCM.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 4
-optim_type = AdamW
-lr = 2e-5
-betas = (0.9, 0.999)
-weight_decay = 0
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=False, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- llm_lora=dict(
- type=LoraConfig,
- r=64,
- lora_alpha=16,
- lora_dropout=0.1,
- bias='none',
- task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='attention',
- in_tokens=4096,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=4096, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-load_from = None
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(by_epoch=False)
\ No newline at end of file
diff --git a/code/xtuner/configs/slidechat/stage_2_text_reducer_attn.py b/code/xtuner/configs/slidechat/stage_2_text_reducer_attn.py
deleted file mode 100644
index 716cb0befba93b94275c662d587a6d9a3f357c73..0000000000000000000000000000000000000000
--- a/code/xtuner/configs/slidechat/stage_2_text_reducer_attn.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmengine.dataset import DefaultSampler
-from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
- LoggerHook, ParamSchedulerHook)
-from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
-from torch.optim import AdamW
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel)
-
-from xtuner.dataset import LLaVADataset
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook, HFCheckpointHook
-from xtuner.engine.runner import TrainLoop
-from xtuner.model import LLaVAModelWithReducer
-from xtuner.utils import PROMPT_TEMPLATE
-
-from peft import LoraConfig
-#######################################################################
-# PART 1 Settings #
-#######################################################################
-# Model
-llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
-
-# Data
-data_path = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_/LGG.json'
-image_path_list = None
-pretrained_pth = '/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth'
-
-prompt_template = PROMPT_TEMPLATE.qwen_chat
-max_length = 19600
-per_image_length = 196
-sample_type='wsi'
-
-# Scheduler & Optimizer
-batch_size = 1
-accumulative_counts = 8
-dataloader_num_workers = 64
-max_epochs = 3
-optim_type = AdamW
-lr = 1.5e-5
-betas = (0.9, 0.999)
-weight_decay = 0.05
-max_norm = 1
-
-# Save
-save_steps = 500
-save_total_limit = 3
-
-# Evaluate
-evaluation_freq = 1000
-SYSTEM = ''
-evaluation_images = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
-evaluation_inputs = ['Generate an overview summarizing the principal findings from the pathology examination of the whole slide image.']
-
-#######################################################################
-# PART 2 Model & Tokenizer & Image Processor #
-#######################################################################
-tokenizer = dict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- padding_side='right')
-
-model = dict(
- type=LLaVAModelWithReducer,
- freeze_llm=True, # Set to False to enable LoRA training
- pretrained_pth=pretrained_pth,
- train_stage='2',
-
- llm=dict(
- type=AutoModelForCausalLM.from_pretrained,
- pretrained_model_name_or_path=llm_name_or_path,
- trust_remote_code=True,
- torch_dtype=torch.float16,
- ),
- # LoRA Config
- # llm_lora=dict(
- # type=LoraConfig,
- # r=64,
- # lora_alpha=16,
- # lora_dropout=0.1,
- # bias='none',
- # task_type='CAUSAL_LM'),
-
- # 1d convolution for visual token reduction
- # visual_token_reducer_config=dict(
- # kernel_size = 4,
- # stride = 4)
-
-
- # using MLP for visual token reduction
- visual_token_reducer_config=dict(
- type='text_guided_attention',
- in_tokens=10000,
- out_tokens=2048,
- num_heads=8,
- num_queries=2048
- )
-
- )
-
-#######################################################################
-# PART 3 Dataset & Dataloader #
-#######################################################################
-llava_dataset = dict(
- type=LLaVADataset,
- data_path=data_path,
- image_folder='',
- image_path_list=image_path_list,
- tokenizer=tokenizer,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=prompt_template),
- max_length=max_length,
- per_image_length=per_image_length,
- pad_image_to_square=False,
- sample_num=10240, # max patch number
- image_feature_prefix = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1',
- image_feature_suffix='.pt',
- identifier='_224x224_b20_t15',
- )
-
-
-train_dataloader = dict(
- batch_size=batch_size,
- num_workers=dataloader_num_workers,
- pin_memory=True,
- dataset=llava_dataset,
- sampler=dict(type=DefaultSampler, shuffle=True),
- collate_fn=dict(type=default_collate_fn))
-
-#######################################################################
-# PART 4 Scheduler & Optimizer #
-#######################################################################
-optim_wrapper = dict(
- type=AmpOptimWrapper,
- optimizer=dict(
- type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
- clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
- accumulative_counts=accumulative_counts,
- loss_scale='dynamic',
- dtype='float16')
-
-param_scheduler = [
- dict(
- type=CosineAnnealingLR,
- eta_min=0.0,
- by_epoch=True,
- begin=0,
- end=max_epochs,
- convert_to_iter_based=True)
-]
-
-train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
-
-#######################################################################
-# PART 5 Runtime #
-#######################################################################
-custom_hooks = [
- dict(type=DatasetInfoHook, tokenizer=tokenizer),
- dict(
- type=EvaluateChatHook,
- tokenizer=tokenizer,
- every_n_iters=evaluation_freq,
- evaluation_inputs=evaluation_inputs,
- evaluation_images=evaluation_images,
- system=SYSTEM,
- prompt_template=prompt_template)
-]
-
-default_hooks = dict(
- timer=dict(type=IterTimerHook),
- logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
- param_scheduler=dict(type=ParamSchedulerHook),
- checkpoint=dict(
- type=CheckpointHook,
- by_epoch=False,
- interval=save_steps,
- max_keep_ckpts=save_total_limit),
- sampler_seed=dict(type=DistSamplerSeedHook),
-)
-
-env_cfg = dict(
- cudnn_benchmark=False,
- mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
- dist_cfg=dict(backend='nccl'),
-)
-
-visualizer = None
-log_level = 'INFO'
-# load from which checkpoint
-load_from = None
-# whether to resume training from the loaded checkpoint
-resume = False
-randomness = dict(seed=None, deterministic=False)
-log_processor = dict(
- by_epoch=False,
- window_size=1,
- mean_pattern=r".*(loss|time|data_time|grad_norm|tflops).*",
-)
\ No newline at end of file
diff --git a/code/xtuner/dataset/__init__.py b/code/xtuner/dataset/__init__.py
deleted file mode 100644
index 51d06eab61156a7d1672e247c45e8087626a9e47..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/__init__.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-
-from .concat_dataset import ConcatDataset
-from .huggingface import process_hf_dataset
-from .intern_repo import (build_packed_dataset,
- load_intern_repo_tokenized_dataset,
- load_intern_repo_untokenized_dataset)
-from .json_dataset import load_json_file
-from .llava import LLaVADataset
-from .modelscope import process_ms_dataset
-from .moss_sft import MOSSSFTDataset
-from .refcoco_json import (InvRefCOCOJsonDataset, RefCOCOJsonDataset,
- RefCOCOJsonEvalDataset)
-from .utils import decode_base64_to_image, expand2square, load_image
-
-# ignore FutureWarning in hf datasets
-warnings.simplefilter(action='ignore', category=FutureWarning)
-
-
-
-# cying
-# __all__ = [
-# 'process_hf_dataset', 'ConcatDataset', 'MOSSSFTDataset',
-# 'process_ms_dataset', 'LLaVADataset', 'expand2square',
-# 'decode_base64_to_image', 'load_image', 'process_ms_dataset',
-# 'load_intern_repo_tokenized_dataset',
-# 'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
-# 'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset',
-# 'load_json_file'
-# ]
-__all__ = [
- 'process_hf_dataset', 'ConcatDataset', 'MOSSSFTDataset',
- 'process_ms_dataset', 'LLaVADataset', 'expand2square',
- 'decode_base64_to_image', 'load_image', 'process_ms_dataset',
- 'load_intern_repo_tokenized_dataset',
- 'load_intern_repo_untokenized_dataset', 'build_packed_dataset',
- 'RefCOCOJsonDataset', 'RefCOCOJsonEvalDataset', 'InvRefCOCOJsonDataset'
-]
-# cying
\ No newline at end of file
diff --git a/code/xtuner/dataset/__pycache__/__init__.cpython-311.pyc b/code/xtuner/dataset/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index ecc73851ca18c9a0520f47c12fc8d3bd36846726..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/concat_dataset.cpython-311.pyc b/code/xtuner/dataset/__pycache__/concat_dataset.cpython-311.pyc
deleted file mode 100644
index a75f74ecc5682d9b8b7e7ab1576d60a136346f4c..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/concat_dataset.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/huggingface.cpython-311.pyc b/code/xtuner/dataset/__pycache__/huggingface.cpython-311.pyc
deleted file mode 100644
index 261c2ebc7e76d9c47ef72898cddcb3499c8562bf..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/huggingface.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/intern_repo.cpython-311.pyc b/code/xtuner/dataset/__pycache__/intern_repo.cpython-311.pyc
deleted file mode 100644
index f7abdf39546645e540b79073fd685fabfae5d314..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/intern_repo.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/json_dataset.cpython-311.pyc b/code/xtuner/dataset/__pycache__/json_dataset.cpython-311.pyc
deleted file mode 100644
index dfaa03711a4b8571427b2c5b96722965d2767318..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/json_dataset.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/llava.cpython-311.pyc b/code/xtuner/dataset/__pycache__/llava.cpython-311.pyc
deleted file mode 100644
index 8092867bcba881f8059d8a9ee2dd0b2d5a896480..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/llava.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/modelscope.cpython-311.pyc b/code/xtuner/dataset/__pycache__/modelscope.cpython-311.pyc
deleted file mode 100644
index 3cad74c21bf50bc214b15b77bdc50211c2e85073..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/modelscope.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/moss_sft.cpython-311.pyc b/code/xtuner/dataset/__pycache__/moss_sft.cpython-311.pyc
deleted file mode 100644
index 56d5befda97b1b25f8913faa97626bf74d5c6588..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/moss_sft.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/refcoco_json.cpython-311.pyc b/code/xtuner/dataset/__pycache__/refcoco_json.cpython-311.pyc
deleted file mode 100644
index 7cb070ebec34536e8b25d48f697f77eacecca2ae..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/refcoco_json.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/__pycache__/utils.cpython-311.pyc b/code/xtuner/dataset/__pycache__/utils.cpython-311.pyc
deleted file mode 100644
index 9553e8a5c02f97f669c4d27d3277aae723d0e32a..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/__pycache__/utils.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/collate_fns/__init__.py b/code/xtuner/dataset/collate_fns/__init__.py
deleted file mode 100644
index 96652b2599c75353faad7d54b11622f7ccee7eb3..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/collate_fns/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .default_collate_fn import default_collate_fn
-from .mmlu_collate_fn import mmlu_collate_fn
-
-__all__ = ['default_collate_fn', 'mmlu_collate_fn']
diff --git a/code/xtuner/dataset/collate_fns/__pycache__/__init__.cpython-311.pyc b/code/xtuner/dataset/collate_fns/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index b60361fe059b10812aa8fd41d93713d124c0e799..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/collate_fns/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/collate_fns/__pycache__/default_collate_fn.cpython-311.pyc b/code/xtuner/dataset/collate_fns/__pycache__/default_collate_fn.cpython-311.pyc
deleted file mode 100644
index a1e8ae02048207326b8e35e21ac27314c99955bc..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/collate_fns/__pycache__/default_collate_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/collate_fns/__pycache__/mmlu_collate_fn.cpython-311.pyc b/code/xtuner/dataset/collate_fns/__pycache__/mmlu_collate_fn.cpython-311.pyc
deleted file mode 100644
index 81bd0facfa68a2e7d554bb93389e55a01abcacd3..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/collate_fns/__pycache__/mmlu_collate_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/collate_fns/default_collate_fn.py b/code/xtuner/dataset/collate_fns/default_collate_fn.py
deleted file mode 100644
index d62c91b1e7fccde4ac337f3dcd71bb075ef41145..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/collate_fns/default_collate_fn.py
+++ /dev/null
@@ -1,226 +0,0 @@
-# # Copyright (c) OpenMMLab. All rights reserved.
-# from typing import Dict, Sequence
-# import numpy as np
-# import torch
-# from torch.nn.utils.rnn import pad_sequence
-
-# from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
-# pad_for_sequence_parallel)
-# from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
-
-
-# def default_collate_fn(instances: Sequence[Dict],
-# pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
-# return_hf_format: bool = False,
-# use_varlen_attn: bool = False):
-# seq_parallel_world_size = get_sequence_parallel_world_size()
-
-# input_ids, labels = [], []
-# has_image = any(inst.get('pixel_values') is not None for inst in instances)
-# if use_varlen_attn:
-# position_ids, cumulative_len = [], []
-# assert len(instances) == 1, (
-# f'If utilizing varlen attention, the batch size should be'
-# f' set to 1, but got {len(instances)}')
-# assert not has_image, 'Currently, it is not configured to '
-# 'accommodate the use of varlen Attention in multimodal training'
-
-# if has_image:
-# pixel_values = []
-
-# for example in instances:
-# input_ids.append(torch.LongTensor(example['input_ids']))
-# labels.append(torch.LongTensor(example['labels']))
-# if use_varlen_attn:
-# cumulative_len.append(torch.IntTensor(example['cumulative_len']))
-# position_ids.append(torch.LongTensor(example['position_ids']))
-
-# if has_image:
-# # cying
-# # pixel_values.append(example['pixel_values'])
-
-# if isinstance(example['pixel_values'], list):
-# pixel_values.extend(example['pixel_values'])
-# else:
-# pixel_values.append(example['pixel_values'])
-# # cying
-
-# ori_length = [len(ids) for ids in input_ids]
-# if len(instances) > 1:
-# input_ids = pad_sequence(
-# input_ids, batch_first=True, padding_value=pad_index)
-# labels = pad_sequence(
-# labels, batch_first=True, padding_value=IGNORE_INDEX)
-# else:
-# input_ids = torch.stack(input_ids)
-# labels = torch.stack(labels)
-
-# if use_varlen_attn:
-# assert input_ids.size(1) % seq_parallel_world_size == 0
-# attention_mask = None
-# position_ids = torch.stack(position_ids, dim=0)
-# else:
-# # Some tokenizers have the same eos token and pad token, so input_ids
-# # cannot be masked directly based on the pad token id.
-# attention_mask = torch.zeros_like(input_ids).bool()
-# for i in ori_length:
-# attention_mask[:i] = True
-
-# bs, seq_len = input_ids.shape
-# position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
-
-# if seq_parallel_world_size > 1:
-# input_ids = pad_for_sequence_parallel(input_ids, pad_index)
-# labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
-# position_ids = pad_for_sequence_parallel(position_ids, 0)
-# if attention_mask is not None:
-# attention_mask = pad_for_sequence_parallel(attention_mask, 0)
-
-# if use_varlen_attn:
-# max_seqlen = (
-# cumulative_len[0][1:] - # noqa: W504
-# cumulative_len[0][:-1]).max().item()
-# data_dict = {
-# 'input_ids': input_ids,
-# 'cumulative_len': cumulative_len,
-# 'position_ids': position_ids,
-# 'labels': labels,
-# 'max_seqlen': max_seqlen
-# }
-# else:
-# data_dict = {
-# 'input_ids': input_ids,
-# 'attention_mask': attention_mask,
-# 'position_ids': position_ids,
-# 'labels': labels
-# }
-
-# if has_image:
-
-# pixel_values = torch.stack(pixel_values)
-# data_dict['pixel_values'] = pixel_values
-
-
-# if return_hf_format:
-# return data_dict
-# else:
-# return {'data': data_dict, 'data_samples': None}
-
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence
-import numpy as np
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
- pad_for_sequence_parallel)
-from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
-
-
-def default_collate_fn(instances: Sequence[Dict],
- pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
- return_hf_format: bool = False,
- use_varlen_attn: bool = False):
- seq_parallel_world_size = get_sequence_parallel_world_size()
-
- input_ids, labels = [], []
- has_image = any(inst.get('pixel_values') is not None for inst in instances)
- if use_varlen_attn:
- position_ids, cumulative_len = [], []
- assert len(instances) == 1, (
- f'If utilizing varlen attention, the batch size should be'
- f' set to 1, but got {len(instances)}')
- assert not has_image, 'Currently, it is not configured to '
- 'accommodate the use of varlen Attention in multimodal training'
-
- if has_image:
- pixel_values = []
- # NEW: detect if any instance provides coords
- has_coords = any(inst.get('coords') is not None for inst in instances)
- if has_coords:
- coords = []
-
- for example in instances:
- input_ids.append(torch.LongTensor(example['input_ids']))
- labels.append(torch.LongTensor(example['labels']))
- if use_varlen_attn:
- cumulative_len.append(torch.IntTensor(example['cumulative_len']))
- position_ids.append(torch.LongTensor(example['position_ids']))
-
- if has_image:
- # Flatten any per-sample list of pixel_value tensors across the batch
- if isinstance(example['pixel_values'], list):
- pixel_values.extend(example['pixel_values'])
- else:
- pixel_values.append(example['pixel_values'])
-
- # NEW: mirror the same flattening for coords if present
- if 'coords' in example and example['coords'] is not None:
- if isinstance(example['coords'], list):
- coords.extend(example['coords'])
- else:
- coords.append(example['coords'])
-
- ori_length = [len(ids) for ids in input_ids]
- if len(instances) > 1:
- input_ids = pad_sequence(
- input_ids, batch_first=True, padding_value=pad_index)
- labels = pad_sequence(
- labels, batch_first=True, padding_value=IGNORE_INDEX)
- else:
- input_ids = torch.stack(input_ids)
- labels = torch.stack(labels)
-
- if use_varlen_attn:
- assert input_ids.size(1) % seq_parallel_world_size == 0
- attention_mask = None
- position_ids = torch.stack(position_ids, dim=0)
- else:
- # Some tokenizers have the same eos token and pad token, so input_ids
- # cannot be masked directly based on the pad token id.
- attention_mask = torch.zeros_like(input_ids).bool()
- for i in ori_length:
- attention_mask[:i] = True
-
- bs, seq_len = input_ids.shape
- position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
-
- if seq_parallel_world_size > 1:
- input_ids = pad_for_sequence_parallel(input_ids, pad_index)
- labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
- position_ids = pad_for_sequence_parallel(position_ids, 0)
- if attention_mask is not None:
- attention_mask = pad_for_sequence_parallel(attention_mask, 0)
-
- if use_varlen_attn:
- max_seqlen = (
- cumulative_len[0][1:] - # noqa: W504
- cumulative_len[0][:-1]).max().item()
- data_dict = {
- 'input_ids': input_ids,
- 'cumulative_len': cumulative_len,
- 'position_ids': position_ids,
- 'labels': labels,
- 'max_seqlen': max_seqlen
- }
- else:
- data_dict = {
- 'input_ids': input_ids,
- 'attention_mask': attention_mask,
- 'position_ids': position_ids,
- 'labels': labels
- }
-
- if has_image:
- pixel_values = torch.stack(pixel_values)
- data_dict['pixel_values'] = pixel_values
-
- # NEW: add stacked coords if any were provided by the dataset
- if 'has_coords' in locals() and has_coords and len(coords) > 0:
- # Expect each coords tensor to be shape (N, 2) and aligned with pixel_values tensors.
- data_dict['coords'] = torch.stack(coords)
-
- if return_hf_format:
- return data_dict
- else:
- return {'data': data_dict, 'data_samples': None}
\ No newline at end of file
diff --git a/code/xtuner/dataset/collate_fns/default_collate_fn_raw.py b/code/xtuner/dataset/collate_fns/default_collate_fn_raw.py
deleted file mode 100644
index 3d9fe18fb166c5849ae9d1d658f516c4e4b0590c..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/collate_fns/default_collate_fn_raw.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence
-
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
- pad_for_sequence_parallel)
-from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
-
-
-def default_collate_fn(instances: Sequence[Dict],
- pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
- return_hf_format: bool = False,
- use_varlen_attn: bool = False):
- seq_parallel_world_size = get_sequence_parallel_world_size()
-
- input_ids, labels = [], []
- has_image = any(inst.get('pixel_values') is not None for inst in instances)
- if use_varlen_attn:
- position_ids, cumulative_len = [], []
- assert len(instances) == 1, (
- f'If utilizing varlen attention, the batch size should be'
- f' set to 1, but got {len(instances)}')
- assert not has_image, 'Currently, it is not configured to '
- 'accommodate the use of varlen Attention in multimodal training'
-
- if has_image:
- pixel_values = []
-
- for example in instances:
- input_ids.append(torch.LongTensor(example['input_ids']))
- labels.append(torch.LongTensor(example['labels']))
- if use_varlen_attn:
- cumulative_len.append(torch.IntTensor(example['cumulative_len']))
- position_ids.append(torch.LongTensor(example['position_ids']))
-
- if has_image:
- pixel_values.append(example['pixel_values'])
-
- ori_length = [len(ids) for ids in input_ids]
- if len(instances) > 1:
- input_ids = pad_sequence(
- input_ids, batch_first=True, padding_value=pad_index)
- labels = pad_sequence(
- labels, batch_first=True, padding_value=IGNORE_INDEX)
- else:
- input_ids = torch.stack(input_ids)
- labels = torch.stack(labels)
-
- if use_varlen_attn:
- assert input_ids.size(1) % seq_parallel_world_size == 0
- attention_mask = None
- position_ids = torch.stack(position_ids, dim=0)
- else:
- # Some tokenizers have the same eos token and pad token, so input_ids
- # cannot be masked directly based on the pad token id.
- attention_mask = torch.zeros_like(input_ids).bool()
- for i, length in enumerate(ori_length):
- attention_mask[i, :length] = True
-
- bs, seq_len = input_ids.shape
- position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
-
- if seq_parallel_world_size > 1:
- input_ids = pad_for_sequence_parallel(input_ids, pad_index)
- labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
- position_ids = pad_for_sequence_parallel(position_ids, 0)
- if attention_mask is not None:
- attention_mask = pad_for_sequence_parallel(attention_mask, 0)
-
- if use_varlen_attn:
- max_seqlen = (
- cumulative_len[0][1:] - # noqa: W504
- cumulative_len[0][:-1]).max().item()
- data_dict = {
- 'input_ids': input_ids,
- 'cumulative_len': cumulative_len,
- 'position_ids': position_ids,
- 'labels': labels,
- 'max_seqlen': max_seqlen
- }
- else:
- data_dict = {
- 'input_ids': input_ids,
- 'attention_mask': attention_mask,
- 'position_ids': position_ids,
- 'labels': labels
- }
-
- if has_image:
- if all(x.shape == pixel_values[0].shape for x in pixel_values):
- pixel_values = torch.stack(pixel_values, dim=0)
- data_dict['pixel_values'] = pixel_values
-
- if return_hf_format:
- return data_dict
- else:
- return {'data': data_dict, 'data_samples': None}
diff --git a/code/xtuner/dataset/collate_fns/mmlu_collate_fn.py b/code/xtuner/dataset/collate_fns/mmlu_collate_fn.py
deleted file mode 100644
index 5c0e2a9894f897cbe7ed80680b15b364e767a33c..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/collate_fns/mmlu_collate_fn.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence
-
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
-
-
-def mmlu_collate_fn(instances: Sequence[Dict],
- pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
- return_hf_format: bool = False) -> Dict[str, torch.Tensor]:
- input_ids = []
- labels = []
- data_samples = {'labels': [], 'subjects': []}
- for example in instances:
- input_ids.append(torch.tensor(example['input_ids']))
- labels.append(torch.tensor(example['labels']))
- data_samples['labels'].append(example['output'])
- data_samples['subjects'].append(example['subject'])
- if len(instances) > 1:
- input_ids = pad_sequence(
- input_ids, batch_first=True, padding_value=pad_index)
- labels = pad_sequence(
- labels, batch_first=True, padding_value=IGNORE_INDEX)
- else:
- input_ids = torch.stack(input_ids)
- labels = torch.stack(labels)
-
- data_dict = {
- 'input_ids': input_ids,
- 'attention_mask': input_ids.ne(pad_index),
- 'labels': labels
- }
-
- if return_hf_format:
- return data_dict
- else:
- return {'data': data_dict, 'data_samples': data_samples}
diff --git a/code/xtuner/dataset/collate_fns/preference_collate_fn.py b/code/xtuner/dataset/collate_fns/preference_collate_fn.py
deleted file mode 100644
index 4b6a7f5c3eacdeb97b402ad340f3d67a6d7fbccb..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/collate_fns/preference_collate_fn.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence
-
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
- pad_cumulative_len_for_sequence_parallel,
- pad_for_sequence_parallel)
-from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
-
-
-def preference_collate_fn(instances: Sequence[Dict],
- pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
- return_hf_format: bool = False,
- use_varlen_attn: bool = False):
- seq_parallel_world_size = get_sequence_parallel_world_size()
- ds_names = []
- if not use_varlen_attn:
- # split chosen and rejected into two instances
- splited_instances = []
- for d in instances:
- splited_instances.append({
- 'input_ids': d['chosen_ids'],
- 'labels': d['chosen_labels']
- })
- splited_instances.append({
- 'input_ids': d['rejected_ids'],
- 'labels': d['rejected_labels']
- })
- ds_names.append(d.get('ds_name', None))
- instances = splited_instances
-
- input_ids, labels = [], []
- if use_varlen_attn:
- position_ids, cumulative_len = [], []
- assert len(instances) == 1, (
- f'If utilizing varlen attention, the batch size should be'
- f' set to 1, but got {len(instances)}')
-
- for example in instances:
- input_ids.append(torch.LongTensor(example['input_ids']))
- labels.append(torch.LongTensor(example['labels']))
- if use_varlen_attn:
- cumulative_len.append(torch.IntTensor(example['cumulative_len']))
- position_ids.append(torch.LongTensor(example['position_ids']))
- num_samples = (len(example['cumulative_len']) - 1) // 2
- ds_names.extend(example.get('ds_names', [None] * num_samples))
-
- ori_length = [len(ids) for ids in input_ids]
- if len(instances) > 1:
- input_ids = pad_sequence(
- input_ids, batch_first=True, padding_value=pad_index)
- labels = pad_sequence(
- labels, batch_first=True, padding_value=IGNORE_INDEX)
- else:
- input_ids = torch.stack(input_ids)
- labels = torch.stack(labels)
-
- if use_varlen_attn:
- attention_mask = None
- position_ids = torch.stack(position_ids, dim=0)
- else:
- # Some tokenizers have the same eos token and pad token, so input_ids
- # cannot be masked directly based on the pad token id.
- attention_mask = torch.zeros_like(input_ids).bool()
- for i, length in enumerate(ori_length):
- attention_mask[i, :length] = True
-
- bs, seq_len = input_ids.shape
- position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
-
- if seq_parallel_world_size > 1:
- input_ids = pad_for_sequence_parallel(input_ids, pad_index)
- labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
- position_ids = pad_for_sequence_parallel(position_ids, 0)
- if attention_mask is not None:
- attention_mask = pad_for_sequence_parallel(attention_mask, 0)
- if use_varlen_attn:
- # We use attention_mask to distinguish `input_ids` from
- # (sequence parallel) pad tokens in `get_var_len_atten_logps`
- # method of class `DPO` and `ORPO`
- (cumulative_len, attention_mask
- ) = pad_cumulative_len_for_sequence_parallel(cumulative_len)
-
- if use_varlen_attn:
- max_seqlen = (
- cumulative_len[0][1:] - # noqa: W504
- cumulative_len[0][:-1]).max().item()
- data_dict = {
- 'input_ids': input_ids,
- 'attention_mask': attention_mask,
- 'cumulative_len': cumulative_len,
- 'position_ids': position_ids,
- 'labels': labels,
- 'max_seqlen': max_seqlen
- }
- else:
- data_dict = {
- 'input_ids': input_ids,
- 'attention_mask': attention_mask,
- 'position_ids': position_ids,
- 'labels': labels
- }
-
- if return_hf_format:
- return data_dict
- else:
- return {'data': data_dict, 'data_samples': {'ds_names': ds_names}}
diff --git a/code/xtuner/dataset/concat_dataset.py b/code/xtuner/dataset/concat_dataset.py
deleted file mode 100644
index 18d0a4c2f1d68755768132aa97d6852ac7b311e1..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/concat_dataset.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from torch.utils.data import ConcatDataset as TorchConcatDataset
-
-from xtuner.registry import BUILDER
-
-
-class ConcatDataset(TorchConcatDataset):
-
- def __init__(self, datasets):
- datasets_instance = []
- for cfg in datasets:
- datasets_instance.append(BUILDER.build(cfg))
- super().__init__(datasets=datasets_instance)
-
- def __repr__(self):
- main_str = 'Dataset as a concatenation of multiple datasets. \n'
- main_str += ',\n'.join(
- [f'{repr(dataset)}' for dataset in self.datasets])
- return main_str
diff --git a/code/xtuner/dataset/huggingface.py b/code/xtuner/dataset/huggingface.py
deleted file mode 100644
index c9ef68de888471f4f6201a1f21ae85e7a7331195..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/huggingface.py
+++ /dev/null
@@ -1,346 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-import os
-from datetime import timedelta
-from functools import partial
-
-import numpy as np
-from datasets import DatasetDict, concatenate_datasets
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.utils.misc import get_object_from_string
-from torch import distributed as dist
-
-from xtuner.registry import BUILDER, MAP_FUNC
-from .utils import Packer, encode_fn
-
-
-def get_lengths(example):
- return {'length': len(example['input_ids'])}
-
-
-def build_origin_dataset(dataset, split):
- if isinstance(dataset, DatasetDict):
- if split is None:
- dataset = concatenate_datasets(dataset.values())
- else:
- dataset = dataset[split]
- elif isinstance(dataset, dict) or isinstance(
- dataset, Config) or isinstance(dataset, ConfigDict):
- dataset = BUILDER.build(dataset)
- if isinstance(dataset, DatasetDict):
- if split is None:
- dataset = concatenate_datasets(dataset.values())
- else:
- dataset = dataset[split]
- return dataset
-
-
-def map_dataset(dataset, dataset_map_fn, map_num_proc):
- if isinstance(dataset_map_fn, str):
- map_fn_obj = MAP_FUNC.get(dataset_map_fn) or get_object_from_string(
- dataset_map_fn)
- if map_fn_obj is not None:
- dataset_map_fn = map_fn_obj
- else:
- raise TypeError('dataset_map_fn must be a function or a '
- "registered function's string in MAP_FUNC, "
- f"but got a string of '{dataset_map_fn}'")
-
- dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc)
- return dataset
-
-
-def add_template_to_dataset(dataset, template_map_fn, map_num_proc):
- if isinstance(template_map_fn,
- dict) or isinstance(template_map_fn, Config) or isinstance(
- template_map_fn, ConfigDict):
- template_map_fn = BUILDER.build(template_map_fn)
- dataset = dataset.map(template_map_fn, num_proc=map_num_proc)
- # remove invalid data
- dataset = dataset.filter(
- lambda example: len(example['conversation']) > 0,
- num_proc=map_num_proc)
- return dataset
-
-
-def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
- per_image_length,
- input_ids_with_output, remove_unused_columns,
- map_num_proc):
- assert (tokenizer is not None) and (max_length is not None), \
- f'({tokenizer}, {max_length})'
- if isinstance(tokenizer, dict) or isinstance(
- tokenizer, Config) or isinstance(tokenizer, ConfigDict):
- tokenizer = BUILDER.build(tokenizer)
- dataset = dataset.map(
- partial(
- encode_fn,
- tokenizer=tokenizer,
- max_length=max_length,
- with_image_token=with_image_token,
- per_image_length=per_image_length,
- input_ids_with_output=input_ids_with_output),
- remove_columns=list(dataset.column_names)
- if remove_unused_columns else None,
- num_proc=map_num_proc)
- return dataset
-
-
-def pack_dataset(dataset, max_length, use_varlen_attn, shuffle_before_pack,
- map_num_proc):
- if shuffle_before_pack:
- dataset = dataset.shuffle()
- dataset = dataset.flatten_indices(num_proc=map_num_proc)
- dataset = dataset.map(
- Packer(max_length, use_varlen_attn=use_varlen_attn),
- batched=True,
- num_proc=map_num_proc)
- return dataset
-
-
-def process(dataset,
- do_dataset_tokenization=True,
- tokenizer=None,
- max_length=None,
- dataset_map_fn=None,
- template_map_fn=None,
- max_dataset_length=None,
- split='train',
- remove_unused_columns=False,
- rename_maps=[],
- shuffle_before_pack=True,
- pack_to_max_length=True,
- use_varlen_attn=False,
- input_ids_with_output=True,
- with_image_token=False,
- per_image_length=None,
- map_num_proc=32):
- """Post-process the dataset loaded from the Hugging Face Hub, or a local
- dataset.
-
- Args:
- dataset: The dataset to be post-processed.
- do_dataset_tokenization: Whether the dataset need to be tokenized
- in this function. Default to True.
- tokenizer: The tokenizer processes some raw text as input and outputs
- an Encoding. If `do_dataset_tokenization` is True, this argument
- should not be None. Default to None.
- max_length: Max length of the sequence. If `do_dataset_tokenization`
- or `pack_to_max_length` is True, this argument should not be None.
- Default to None.
- dataset_map_fn: Map the original dataset format to the one defined
- by xTuner.
- template_map_fn: Add the prompt template to the dataset
- max_dataset_length: If the length of the dataset is too long, we can
- randomly extract `max_dataset_length` from it.
- split: Which split of the data to load.
- If `None`, will return a single concatenated dataset with all
- splits (typically `datasets.Split.TRAIN` and
- `datasets.Split.TEST`).
- If given, will return a single Dataset.
- remove_unused_columns: Whether to remove columns from the dataset
- that are not used during training.
- rename_maps: Rename the column name of the dataset.
- shuffle_before_pack: Whether to shuffle the dataset before
- packing them.
- pack_to_max_length: Whether to pack the dataset to the `max_length `.
- This usually improves gpu utilization and therefore reduces
- training time.
- use_varlen_attn: If use_varlen_attn is True, we calculate attention
- the actual length of the sequence rather than the actual length
- of the sequence
- input_ids_with_output: Whether to put the groundtruth output
- corresponding to the question into the dataset. Typically set
- it to True during training and False during testing.
- with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
- IMAGE_TOKEN_INDEX. Typically set it to True during the training
- of VLM.
- map_num_proc: Max number of processes when mapping the dataset.
- """
- if use_varlen_attn:
- assert pack_to_max_length, \
- '`pack_to_max_length` in `process_hf_dataset` should be set to ' \
- 'True if `use_varlen_attn` is True.'
- if pack_to_max_length:
- assert split == 'train' or split is None, \
- ('`split` should be `train` or `None` if `pack_to_max_length` is '
- f'True, but got {split}.')
-
- # dataset = build_origin_dataset(dataset, split)
- dataset = build_origin_dataset(dataset=dataset, split=split)
-
- # sample `max_dataset_length` items from the original dataset to
- # save time consumed by map function
- if max_dataset_length is not None:
- max_dataset_length = min(max_dataset_length, len(dataset))
- indices = np.random.choice(
- len(dataset), max_dataset_length, replace=False)
- dataset = dataset.select(indices)
-
- # Extract the useful data for training from the original dataset.
- if dataset_map_fn is not None:
- # dataset = map_dataset(dataset, dataset_map_fn, map_num_proc)
- dataset = map_dataset(
- dataset=dataset,
- dataset_map_fn=dataset_map_fn,
- map_num_proc=map_num_proc)
-
- # Add prompt template, such as <|System|>: xxx <|User|>: xxx <|Bot|>: xxx
- if template_map_fn is not None:
- # dataset = add_template_to_dataset(dataset, template_map_fn,
- # map_num_proc)
- dataset = add_template_to_dataset(
- dataset=dataset,
- template_map_fn=template_map_fn,
- map_num_proc=map_num_proc)
-
-
- for old, new in rename_maps:
- dataset = dataset.rename_column(old, new)
-
- # remove unused columns
- if pack_to_max_length and (not remove_unused_columns):
- print_log(
- 'We have to remove unused columns if '
- '`pack_to_max_length` is set to True.',
- logger='current',
- level=logging.WARNING)
- remove_unused_columns = True
-
- if do_dataset_tokenization:
- # dataset = tokenize_dataset(dataset, tokenizer, max_length,
- # with_image_token, input_ids_with_output,
- # remove_unused_columns, map_num_proc)
- dataset = tokenize_dataset(
- dataset=dataset,
- tokenizer=tokenizer,
- max_length=max_length,
- with_image_token=with_image_token,
- per_image_length=per_image_length,
- input_ids_with_output=input_ids_with_output,
- remove_unused_columns=remove_unused_columns,
- map_num_proc=map_num_proc)
-
-
- if input_ids_with_output:
- assert {'input_ids', 'labels'}.issubset(dataset.column_names)
- # remove data that does not have the valid labels.
- dataset = dataset.filter(
- lambda example: any(label >= 0 for label in example['labels']),
- num_proc=map_num_proc)
-
- # pack to max length
- if pack_to_max_length:
- # dataset = pack_dataset(dataset, max_length, use_varlen_attn,
- # shuffle_before_pack, map_num_proc)
- dataset = pack_dataset(
- dataset=dataset,
- max_length=max_length,
- use_varlen_attn=use_varlen_attn,
- shuffle_before_pack=shuffle_before_pack,
- map_num_proc=map_num_proc)
-
- # add 'length'
- dataset = dataset.map(get_lengths, num_proc=map_num_proc)
- setattr(dataset, 'length', dataset['length'])
-
- return dataset
-
-
-def process_hf_dataset(dataset,
- do_dataset_tokenization=True,
- tokenizer=None,
- max_length=None,
- dataset_map_fn=None,
- template_map_fn=None,
- max_dataset_length=None,
- split='train',
- remove_unused_columns=False,
- rename_maps=[],
- shuffle_before_pack=True,
- pack_to_max_length=True,
- use_varlen_attn=False,
- input_ids_with_output=True,
- with_image_token=False,
- per_image_length=None,
- map_num_proc=32):
- """Post-process the dataset loaded from the Hugging Face Hub, or a local
- dataset.
-
- Args:
- dataset: The dataset to be post-processed.
- do_dataset_tokenization: Whether the dataset need to be tokenized
- in this function. Default to True.
- tokenizer: The tokenizer processes some raw text as input and outputs
- an Encoding. If `do_dataset_tokenization` is True, this argument
- should not be None. Default to None.
- max_length: Max length of the sequence. If `do_dataset_tokenization`
- or `pack_to_max_length` is True, this argument should not be None.
- Default to None.
- dataset_map_fn: Map the original dataset format to the one defined
- by xTuner.
- template_map_fn: Add the prompt template to the dataset
- max_dataset_length: If the length of the dataset is too long, we can
- randomly extract `max_dataset_length` from it.
- split: Which split of the data to load.
- If `None`, will return a single concatenated dataset with all
- splits (typically `datasets.Split.TRAIN` and
- `datasets.Split.TEST`).
- If given, will return a single Dataset.
- remove_unused_columns: Whether to remove columns from the dataset
- that are not used during training.
- rename_maps: Rename the column name of the dataset.
- shuffle_before_pack: Whether to shuffle the dataset before
- packing them.
- pack_to_max_length: Whether to pack the dataset to the `max_length `.
- This usually improves gpu utilization and therefore reduces
- training time.
- use_varlen_attn: If use_varlen_attn is True, we calculate attention
- the actual length of the sequence rather than the actual length
- of the sequence
- input_ids_with_output: Whether to put the groundtruth output
- corresponding to the question into the dataset. Typically set
- it to True during training and False during testing.
- with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
- IMAGE_TOKEN_INDEX. Typically set it to True during the training
- of VLM.
- map_num_proc: Max number of processes when mapping the dataset.
- """
- kwargs = dict(
- dataset=dataset,
- do_dataset_tokenization=do_dataset_tokenization,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=dataset_map_fn,
- template_map_fn=template_map_fn,
- max_dataset_length=max_dataset_length,
- split=split,
- remove_unused_columns=remove_unused_columns,
- rename_maps=rename_maps,
- shuffle_before_pack=shuffle_before_pack,
- pack_to_max_length=pack_to_max_length,
- use_varlen_attn=use_varlen_attn,
- input_ids_with_output=input_ids_with_output,
- with_image_token=with_image_token,
- per_image_length=per_image_length,
- map_num_proc=map_num_proc)
- if not (dist.is_available() and dist.is_initialized()):
- return process(**kwargs)
-
- xtuner_dataset_timeout = timedelta(
- minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=60)))
- print_log(
- f'xtuner_dataset_timeout = {xtuner_dataset_timeout}', logger='current')
- # monitored barrier requires gloo process group to perform host-side sync.
- group_gloo = dist.new_group(backend='gloo', timeout=xtuner_dataset_timeout)
-
- if dist.get_rank() == 0:
- dataset = process(**kwargs)
- objects = [dataset]
- else:
- objects = [None]
-
- dist.monitored_barrier(group=group_gloo, timeout=xtuner_dataset_timeout)
- dist.broadcast_object_list(objects, src=0)
- return objects[0]
diff --git a/code/xtuner/dataset/intern_repo.py b/code/xtuner/dataset/intern_repo.py
deleted file mode 100644
index 95cd7cf99ad65da9880ae54235e7791cb6016fd5..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/intern_repo.py
+++ /dev/null
@@ -1,362 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import itertools as it
-import json
-import mmap
-import operator
-import os
-import threading
-from pathlib import Path
-
-import numpy as np
-import torch
-from datasets import Dataset, load_dataset, load_from_disk
-from mmengine import print_log
-from torch import distributed as dist
-from torch.utils.data import ConcatDataset
-
-from xtuner.dataset.map_fns import openai_map_fn
-from xtuner.registry import BUILDER
-from .huggingface import process
-
-
-class JsonlDataset(torch.utils.data.Dataset):
- """
-
- JSONL format is expected to roughly follow that of The Pile.
- One-line-per-document of the form:
- ```
- {
- "input_ids": List[int],
- "labels": List[int]
- }
- ```
-
- """
-
- def __init__(self, path: str, min_length=50):
- self.path = path
- self.threadlocal = threading.local()
- resolved_path = Path(path).resolve()
- self.resolved_path = resolved_path
- self.meta = Path(f'{resolved_path}.meta')
-
- # only build the cache in on the primary worker to prevent
- # overloading nfs
- assert os.path.exists(
- self.meta
- ), f'The cache file:{self.meta} is not found for file:{self.path}'
- try:
- with open(self.meta, 'rb') as f:
- meta = np.load(f)
- except Exception as e:
- print(f'Cannot load file {self.meta}...')
- raise e
- self.offsets = meta[:, 0]
- self.length = meta[:, -1]
-
- if min_length > 0:
- mask = self.length >= min_length
- self.offsets = self.offsets[mask]
- self.length = self.length[mask]
-
- def __getitem__(self, idx):
- f = self._get_mmap()
- position = self.offsets[idx]
- f.seek(position)
- item = f.readline().decode('utf-8')
- try:
- item = json.loads(item)
- item['input_ids'] = item['tokens']
- del item['tokens']
- labels = [x if x > 0 else -100 for x in item['input_ids']]
- item['input_ids'] = [abs(x) for x in item['input_ids']]
- item['labels'] = labels
- item['length'] = len(item['input_ids']) # add a length info
- except Exception as err:
- raise json.decoder.JSONDecodeError(
- doc=self.path,
- pos=position,
- msg=(f'Error while loading JSONL line in file {self.path} '
- f'at byte {position}. Contents of line:\n{item}\n{err}'),
- )
- return item
-
- def get_dataset_name(self):
- return str(self.resolved_path)
-
- def _get_mmap(self):
- if not hasattr(self.threadlocal, 'handles'):
- with open(self.path, 'rb') as f:
- mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
- self.threadlocal.handles = [f, mm]
- if self.path.endswith('.gz') or self.path.endswith(
- '.bz') or self.path.endswith('.bz2'):
- raise NotImplementedError(
- 'Compressed files are not supported because .seek() '
- 'would require rereading the entire file, making '
- 'performance too slow.')
- return self.threadlocal.handles[-1]
-
- def __setstate__(self, state):
- self.__dict__ = state
- self.threadlocal = threading.local()
-
- def __getstate__(self):
- d = {}
- for i, v in self.__dict__.items():
- if i != 'threadlocal':
- d[i] = v
- return d
-
- def __del__(self):
- if hasattr(self.threadlocal, 'handles'):
- # cleanup files we opened on initialization
- while self.threadlocal.handles:
- self.threadlocal.handles.pop().close()
-
- @staticmethod
- def exists(path):
- return os.path.exists(path)
-
- def __len__(self):
- # Virtual length of the dataset depends on the epoch number
- # if the number of documents is not perfectly divisible by the
- # data_subshard_count
- return len(self.offsets)
-
-
-class PackedDataset(torch.utils.data.Dataset):
- """The class PackedDataset takes in a dataset and aggregates samples of
- different lengths together based on the packed_length.
-
- Args:
- dataset: The original dataset to pack.
- packed_length: The length of each packed sample. Default is 8192.
- """
-
- def __init__(self, dataset, packed_length: int = 8192, seed: int = 1024):
- self.dataset = dataset
- self.packed_length = packed_length
- if isinstance(dataset, JsonlDataset):
- self.length = dataset.length
- elif isinstance(dataset, Dataset):
- if hasattr(dataset, 'length'):
- length = dataset.length
- else:
- length = [len(i['input_ids']) for i in dataset]
- self.length = length
- else:
- raise NotImplementedError
- self.seed = seed
-
- rng = np.random.RandomState(self.seed)
- shuffled_indices = np.arange(len(self.length))
- rng.shuffle(shuffled_indices)
- self.shuffled_indices = shuffled_indices.tolist()
- self.shuffled_samples_len = list(
- map(self.length.__getitem__, shuffled_indices))
- self.shuffled_accumulated_samples_len = list(
- it.accumulate(self.shuffled_samples_len, operator.add))
- self.num_tokens = sum(self.length)
-
- def __len__(self):
- return self.num_tokens // self.packed_length
-
- def search_sample_index(self, pack_idx: int = 0):
- assert pack_idx >= 0
- length_train = (pack_idx + 1) * self.packed_length
- sample_index = np.searchsorted(
- self.shuffled_accumulated_samples_len, length_train, side='left')
- return sample_index
-
- def mapping(self, pack_idx: int = 0):
- begin_sample_idx, begin_token_id = 0, 0
- if pack_idx > 0:
- begin_sample_idx = self.search_sample_index(pack_idx - 1)
- # The position where the previous packed data ends
- begin_token_id = self.shuffled_samples_len[begin_sample_idx] - (
- self.shuffled_accumulated_samples_len[begin_sample_idx]
- - # noqa: W504,W503
- (pack_idx) * self.packed_length)
- if begin_token_id == self.shuffled_samples_len[begin_sample_idx]:
- begin_sample_idx += 1
- begin_token_id = 0
-
- end_sample_idx = self.search_sample_index(pack_idx)
- end_token_id = self.shuffled_samples_len[end_sample_idx] - (
- self.shuffled_accumulated_samples_len[end_sample_idx]
- - # noqa: W504,W503
- (pack_idx + 1) * self.packed_length)
- return begin_sample_idx, begin_token_id, end_sample_idx, end_token_id
-
- def build_pack(self, begin_sample_idx: int, begin_token_id: int,
- end_sample_idx: int, end_token_id: int):
- pack, cumulative_len, position_ids, labels = [], [0], [], []
-
- while begin_sample_idx < end_sample_idx:
- sample_idx = self.shuffled_indices[begin_sample_idx]
- sample = self.dataset[sample_idx]
- chunk = sample['input_ids'][begin_token_id:]
- pack.extend(chunk)
- _labels = sample['labels'][begin_token_id:]
- assert len(_labels) == len(chunk), (_labels, chunk)
- labels.extend(_labels)
- cumulative_len.append(cumulative_len[-1] + len(chunk))
- position_ids.extend(list(range(len(chunk))))
- begin_sample_idx = begin_sample_idx + 1
- begin_token_id = 0
-
- sample_idx = self.shuffled_indices[end_sample_idx]
- sample = self.dataset[sample_idx]
- chunk = sample['input_ids'][begin_token_id:
- end_token_id] # fragment of a sample
- _labels = sample['labels'][begin_token_id:end_token_id]
- pack.extend(chunk)
- assert len(_labels) == len(chunk), (_labels, chunk)
- labels.extend(_labels)
- cumulative_len.append(cumulative_len[-1] + len(chunk))
- position_ids.extend(list(range(len(chunk))))
-
- out = {
- 'input_ids': pack,
- 'cumulative_len': cumulative_len,
- 'position_ids': position_ids,
- 'labels': labels
- }
- return out
-
- def __getitem__(self, item: int):
- pos_before, token_id_before, pos_after, token_id_after = self.mapping(
- item)
- return self.build_pack(pos_before, token_id_before, pos_after,
- token_id_after)
-
-
-def load_intern_repo_tokenized_dataset(folder,
- min_length=0,
- data_order_path=None,
- file_type='.bin'):
- assert os.path.exists(folder), f'{folder} does not exist.'
- datasets = []
-
- if data_order_path is not None:
- data_order = load_dataset(
- 'text', data_files=data_order_path, split='train')['text']
- for i, fp in enumerate(data_order):
- data_order[i] = os.path.join(folder, fp)
- else:
- triples = list(os.walk(folder, followlinks=True))
- data_order = []
- for root, dirs, files in triples:
- dirs.sort()
- for fn in sorted(files):
- if fn.endswith(file_type):
- fp = os.path.join(root, fn)
- data_order.append(fp)
-
- for fp in data_order:
- print_log(f'Reading {fp}...', logger='current')
- ds = JsonlDataset(fp, min_length=min_length)
-
- if len(ds) == 0:
- continue
- datasets.append(ds)
-
- return datasets
-
-
-def load_intern_repo_untokenized_dataset(processed_dataset_dict_path=None,
- folder=None,
- tokenizer=None,
- max_length=None,
- template_map_fn=None,
- data_order_path=None,
- file_type='.json'):
-
- assert processed_dataset_dict_path or (folder and tokenizer and max_length)
-
- if processed_dataset_dict_path is not None:
- ds = load_from_disk(processed_dataset_dict_path)
- datasets = []
- for key, data in ds.items():
- datasets.append((key, data))
- datasets = sorted(datasets, key=lambda x: int(x[0]))
- datasets = [x[1] for x in datasets]
- return datasets
-
- assert os.path.exists(folder), f'{folder} does not exist.'
- datasets = []
-
- if data_order_path is not None:
- data_order = load_dataset(
- 'text', data_files=data_order_path, split='train')['text']
- for i, fp in enumerate(data_order):
- data_order[i] = os.path.join(folder, fp)
- else:
- triples = list(os.walk(folder, followlinks=True))
- data_order = []
- for root, dirs, files in triples:
- dirs.sort()
- for fn in sorted(files):
- if fn.endswith(file_type):
- fp = os.path.join(root, fn)
- data_order.append(fp)
-
- for fp in data_order:
- print_log(f'Reading {fp}...', logger='current')
- dataset = []
- with open(fp) as file:
- lines = file.readlines()
- for line in lines:
- line = json.loads(line)
- dataset.append({'messages': line})
- dataset = Dataset.from_list(dataset)
- dataset = process(
- dataset,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=openai_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=True,
- pack_to_max_length=False,
- map_num_proc=32)
-
- if len(dataset) == 0:
- continue
-
- datasets.append(dataset)
-
- return datasets
-
-
-def build_packed_dataset_rank0(dataset_cfg, packed_length=8192, seed=1024):
- if isinstance(dataset_cfg, dict):
- datasets = BUILDER.build(dataset_cfg)
- else:
- datasets = dataset_cfg
-
- if not isinstance(datasets, list):
- datasets = [datasets]
-
- packed_datasets = []
-
- for dataset in datasets:
- ds = PackedDataset(dataset, packed_length, seed=seed)
- packed_datasets.append(ds)
-
- dataset = ConcatDataset(datasets=packed_datasets)
-
- return dataset
-
-
-def build_packed_dataset(*args, **kwargs):
- if not (dist.is_available() and dist.is_initialized()):
- return build_packed_dataset_rank0(*args, **kwargs)
-
- if dist.get_rank() == 0:
- dataset = build_packed_dataset_rank0(*args, **kwargs)
- objects = [dataset]
- else:
- objects = [None]
- dist.broadcast_object_list(objects, src=0)
- return objects[0]
diff --git a/code/xtuner/dataset/internvl_dataset.py b/code/xtuner/dataset/internvl_dataset.py
deleted file mode 100644
index 82904ae8777bd8a6eab9f9fc3b4ed929b6d350ce..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/internvl_dataset.py
+++ /dev/null
@@ -1,409 +0,0 @@
-import copy
-import io
-import json
-import os
-import random
-import warnings
-
-import numpy as np
-import torch
-import torchvision.transforms as T
-from mmengine import print_log
-from mmengine.fileio import get
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision.transforms.functional import InterpolationMode
-from transformers import AutoConfig, AutoTokenizer
-
-from xtuner.utils import IGNORE_INDEX
-
-
-# Referenced from InternVL
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
- image_size):
- best_ratio_diff = float('inf')
- best_ratio = (1, 1)
- area = width * height
- for ratio in target_ratios:
- target_aspect_ratio = ratio[0] / ratio[1]
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
- if ratio_diff < best_ratio_diff:
- best_ratio_diff = ratio_diff
- best_ratio = ratio
- elif ratio_diff == best_ratio_diff:
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
- best_ratio = ratio
- return best_ratio
-
-
-def dynamic_preprocess(image,
- min_num=1,
- max_num=6,
- image_size=448,
- use_thumbnail=False):
- orig_width, orig_height = image.size
- aspect_ratio = orig_width / orig_height
-
- # calculate the existing image aspect ratio
- target_ratios = {(i, j)
- for n in range(min_num, max_num + 1)
- for i in range(1, n + 1) for j in range(1, n + 1)
- if i * j <= max_num and i * j >= min_num}
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- # find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
- target_ratios, orig_width,
- orig_height, image_size)
-
- # calculate the target width and height
- target_width = image_size * target_aspect_ratio[0]
- target_height = image_size * target_aspect_ratio[1]
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
- # resize the image
- resized_img = image.resize((target_width, target_height))
- processed_images = []
- for i in range(blocks):
- box = ((i % (target_width // image_size)) * image_size,
- (i // (target_width // image_size)) * image_size,
- ((i % (target_width // image_size)) + 1) * image_size,
- ((i // (target_width // image_size)) + 1) * image_size)
- # split the image
- split_img = resized_img.crop(box)
- processed_images.append(split_img)
- assert len(processed_images) == blocks
- if use_thumbnail and len(processed_images) != 1:
- thumbnail_img = image.resize((image_size, image_size))
- processed_images.append(thumbnail_img)
- return processed_images
-
-
-def total_image_token(orig_size,
- min_num=1,
- max_num=12,
- image_size=448,
- use_thumbnail=True):
- orig_width, orig_height = orig_size
-
- aspect_ratio = orig_width / orig_height
-
- # calculate the existing image aspect ratio
- target_ratios = {(i, j)
- for n in range(min_num, max_num + 1)
- for i in range(1, n + 1) for j in range(1, n + 1)
- if max_num >= i * j >= min_num}
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
- # find the closest aspect ratio to the target
- target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
- target_ratios, orig_width,
- orig_height, image_size)
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
- if use_thumbnail:
- blocks += 1
-
- return blocks
-
-
-def load_json_or_jsonl(json_path):
- if json_path.endswith('.json'):
- with open(json_path) as f:
- data = json.load(f)
- elif json_path.endswith('.jsonl'):
- with open(json_path) as f:
- data = [json.loads(line) for line in f]
- else:
- raise ValueError(f'Unsupported file format: {json_path}, '
- f'only support .json and .jsonl.')
- return data
-
-
-class InternVL_V1_5_Dataset(Dataset):
- os.environ['TOKENIZERS_PARALLELISM'] = 'true'
- IMG_CONTEXT_TOKEN = ''
- IMG_START_TOKEN = '
'
- IMG_END_TOKEN = ''
-
- IMAGENET_MEAN = (0.485, 0.456, 0.406)
- IMAGENET_STD = (0.229, 0.224, 0.225)
-
- def __init__(self,
- model_path,
- template,
- data_paths,
- image_folders=None,
- repeat_times=1,
- max_length=8192):
- self.template = template
- self.max_length = max_length
-
- self.cfg = AutoConfig.from_pretrained(
- model_path, trust_remote_code=True)
-
- # The following modifications are only to ensure full
- # consistency with the official template,
- # without investigating the impact on performance.
- if self.cfg.llm_config.architectures[0] == 'Phi3ForCausalLM':
- self._system = 'You are an AI assistant whose name is Phi-3.'
- self.template[
- 'INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'
- elif self.cfg.llm_config.architectures[0] == 'InternLM2ForCausalLM':
- self._system = 'You are an AI assistant whose name ' \
- 'is InternLM (书生·浦语).'
- self.template['SYSTEM'] = '<|im_start|>system\n{system}<|im_end|>'
- self.template[
- 'INSTRUCTION'] = '<|im_start|>user\n{input}' \
- '<|im_end|><|im_start|>assistant\n'
- else:
- raise NotImplementedError
-
- self.min_dynamic_patch = self.cfg.min_dynamic_patch
- self.max_dynamic_patch = self.cfg.max_dynamic_patch
- self.downsample_ratio = self.cfg.downsample_ratio
- self.image_size = self.cfg.force_image_size
- self.use_thumbnail = self.cfg.use_thumbnail
- patch_size = self.cfg.vision_config.patch_size
- self.patch_token = int(
- (self.image_size // patch_size)**2 * (self.downsample_ratio**2))
- self.tokenizer = AutoTokenizer.from_pretrained(
- model_path, trust_remote_code=True)
- self.transformer = T.Compose([
- T.Lambda(lambda img: img.convert('RGB')
- if img.mode != 'RGB' else img),
- T.Resize((self.image_size, self.image_size),
- interpolation=InterpolationMode.BICUBIC),
- T.ToTensor(),
- T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
- ])
-
- if not isinstance(data_paths, (list, tuple)):
- data_paths = [data_paths]
- if not isinstance(image_folders, (list, tuple)):
- image_folders = [image_folders]
- if not isinstance(repeat_times, (list, tuple)):
- repeat_times = [repeat_times]
- assert len(data_paths) == len(image_folders) == len(repeat_times)
-
- print_log('Starting to loading data and calc length', logger='current')
- self.data = []
- self.image_folder = []
- self.group_length = []
- self.conv2length_text = {
- } # using dict to speedup the calculation of token length
-
- for data_file, image_folder, repeat_time in zip(
- data_paths, image_folders, repeat_times):
- print_log(
- f'=======Starting to process {data_file} =======',
- logger='current')
- assert repeat_time > 0
- json_data = load_json_or_jsonl(data_file)
- if repeat_time < 1:
- json_data = random.sample(json_data,
- int(len(json_data) * repeat_time))
- elif repeat_time > 1:
- int_repeat_time = int(repeat_time)
- remaining_repeat_time = repeat_time - repeat_time
- if remaining_repeat_time > 0:
- remaining_json_data = random.sample(
- json_data, int(len(json_data) * remaining_repeat_time))
- json_data = json_data * int_repeat_time
- json_data.extend(remaining_json_data)
- else:
- json_data = json_data * int_repeat_time
-
- self.data.extend(json_data)
- self.image_folder.extend([image_folder] * len(json_data))
-
- # TODO: multi process
- for data_item in json_data:
- if 'length' in data_item:
- token_length = data_item['length'] # include image token
- else:
- conversations = '\n'.join(
- [temp['value'] for temp in data_item['conversations']])
- str_length = len(conversations)
-
- if str_length not in self.conv2length_text:
- token_length = self.tokenizer(
- conversations,
- return_tensors='pt',
- padding=False,
- truncation=False,
- ).input_ids.size(1)
- self.conv2length_text[str_length] = token_length
- else:
- token_length = self.conv2length_text[str_length]
-
- if 'image' in data_item and data_item['image'] is not None:
- if 'image_wh' in data_item and data_item[
- 'image_wh'] is not None:
- # more accurate calculation of image token
- image_wh = data_item['image_wh']
- if isinstance(image_wh[0], list):
- image_wh = image_wh[0]
- image_token = total_image_token(
- image_wh, self.min_dynamic_patch,
- self.max_dynamic_patch, self.image_size,
- self.use_thumbnail)
- image_token = self.patch_token * image_token
- else:
- # max_dynamic_patch + use_thumbnail
- image_token = self.patch_token * (
- self.max_dynamic_patch + self.use_thumbnail)
-
- token_length = token_length + image_token
- else:
- token_length = -token_length
-
- self.group_length.append(token_length)
- print_log(
- f'=======total {len(json_data)} samples of {data_file}=======',
- logger='current')
-
- assert len(self.group_length) == len(self.data)
- print_log('end loading data and calc length', logger='current')
- print_log(
- f'=======total {len(self.data)} samples=======', logger='current')
- self._max_refetch = 1000
-
- def __getitem__(self, index):
- for _ in range(self._max_refetch + 1):
- data = self.prepare_data(index)
- # Broken images may cause the returned data to be None
- if data is None:
- index = self._rand_another()
- continue
- return data
-
- def __len__(self):
- return len(self.data)
-
- @property
- def modality_length(self):
- return self.group_length
-
- @property
- def length(self):
- group_length = np.array(self.group_length)
- group_length = np.abs(group_length).tolist()
- return group_length
-
- def prepare_data(self, index):
- data_dict: dict = self.data[index]
- image_folder = self.image_folder[index]
-
- out_data_dict = {}
- if data_dict.get('image', None) is not None:
- image_file = data_dict['image']
- if isinstance(image_file, (list, tuple)):
- assert len(image_file) == 1
- image_file = image_file[0]
-
- try:
- image = self.get_image(os.path.join(image_folder, image_file))
- except Exception as e:
- print(f'Error: {e}', flush=True)
- print_log(f'Error: {e}', logger='current')
- return None
-
- images = dynamic_preprocess(image, self.min_dynamic_patch,
- self.max_dynamic_patch,
- self.image_size, self.use_thumbnail)
- pixel_values = [self.transformer(image) for image in images]
- pixel_values = torch.stack(pixel_values)
- out_data_dict['pixel_values'] = pixel_values
-
- num_image_tokens = pixel_values.shape[0] * self.patch_token
- image_token_str = f'{self.IMG_START_TOKEN}' \
- f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
- f'{self.IMG_END_TOKEN}'
- token_dict = self.get_inputid_labels(data_dict['conversations'],
- image_token_str)
- out_data_dict.update(token_dict)
- else:
- token_dict = self.get_inputid_labels(data_dict['conversations'],
- None)
- out_data_dict.update(token_dict)
- out_data_dict['pixel_values'] = torch.zeros(
- 1, 3, self.image_size, self.image_size)
- return out_data_dict
-
- def _rand_another(self) -> int:
- return np.random.randint(0, len(self.data))
-
- def get_image(self, path):
- if 's3://' in path:
- img_bytes = get(path)
- with io.BytesIO(img_bytes) as buff:
- img = Image.open(buff).convert('RGB')
- return img
- else:
- return Image.open(path).convert('RGB')
-
- def get_inputid_labels(self, conversations, image_token_str) -> dict:
- input = ''
- out_conversation = []
- while conversations and conversations[0]['from'] == 'gpt':
- # Skip the first one if it is from gpt
- conversations = conversations[1:]
- for msg in conversations:
- if msg['from'] == 'human':
- if image_token_str is None and '' in msg['value']:
- warnings.warn(
- f'The current data << {msg["value"]} >> is '
- f'in plain text mode, but '
- 'there are tags present in the data. '
- 'We need to remove the tags.')
- msg['value'] = msg['value'].replace('', '')
- if '' in msg['value']:
- msg['value'] = msg['value'].replace('', '').strip()
- msg['value'] = image_token_str + '\n' + msg['value']
- msg['value'] = msg['value'].strip()
- input += msg['value'].strip()
- elif msg['from'] == 'gpt':
- out_conversation.append({
- 'input': input,
- 'output': msg['value'].strip()
- })
- input = ''
- else:
- raise NotImplementedError
-
- input_ids, labels = [], []
- for i, single_turn_conversation in enumerate(out_conversation):
- input = single_turn_conversation.get('input', '')
- if input is None:
- input = ''
- input_text = self.template.INSTRUCTION.format(
- input=input, round=i + 1)
-
- if i == 0:
- system = self.template.SYSTEM.format(system=self._system)
- input_text = system + input_text
- input_encode = self.tokenizer.encode(
- input_text, add_special_tokens=True)
- else:
- input_encode = self.tokenizer.encode(
- input_text, add_special_tokens=False)
- input_ids += input_encode
- labels += [IGNORE_INDEX] * len(input_encode)
-
- output_text = single_turn_conversation.get('output', '')
- if self.template.get('SUFFIX', None):
- output_text += self.template.SUFFIX
- output_encode = self.tokenizer.encode(
- output_text, add_special_tokens=False)
- input_ids += output_encode
- labels += copy.deepcopy(output_encode)
-
- if len(input_ids) > self.max_length:
- input_ids = input_ids[:self.max_length]
- labels = labels[:self.max_length]
- print_log(
- f'Warning: input_ids length({len(input_ids)}) '
- f'is longer than max_length, cut to {self.max_length}',
- logger='current')
- return {'input_ids': input_ids, 'labels': labels}
diff --git a/code/xtuner/dataset/json_dataset.py b/code/xtuner/dataset/json_dataset.py
deleted file mode 100644
index 1c7ca016300c94d19acb14bf9934d49c156a7987..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/json_dataset.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import json
-import os
-
-from datasets import Dataset, concatenate_datasets
-
-
-def load_json_file(data_files=None, data_dir=None, suffix=None):
- assert (data_files is not None) != (data_dir is not None)
- if data_dir is not None:
- data_files = os.listdir(data_dir)
- data_files = [os.path.join(data_dir, fn) for fn in data_files]
- if suffix is not None:
- data_files = [fp for fp in data_files if fp.endswith(suffix)]
- elif isinstance(data_files, str):
- data_files = [data_files]
-
- dataset_list = []
- for fp in data_files:
- with open(fp, encoding='utf-8') as file:
- data = json.load(file)
- ds = Dataset.from_list(data)
- dataset_list.append(ds)
- dataset = concatenate_datasets(dataset_list)
- return dataset
diff --git a/code/xtuner/dataset/llava.py b/code/xtuner/dataset/llava.py
deleted file mode 100644
index 79d149be1ed6cbf1902028c12c65b7b26640fc58..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/llava.py
+++ /dev/null
@@ -1,846 +0,0 @@
-# import json
-# import logging
-# import os
-# import numpy as np
-# import pandas as pd
-# import torch
-# import h5py
-
-# from datasets import Dataset as HFDataset
-# from datasets import DatasetDict, load_from_disk
-# from mmengine import print_log
-# from torch.utils.data import Dataset, get_worker_info
-
-# from xtuner.registry import BUILDER
-# from .huggingface import process_hf_dataset
-
-# # 映射采样策略到共享整数,便于多进程同步
-# _STRATEGY2ID = {"linspace": 0, "random": 1, "random_full": 2}
-# _ID2STRATEGY = {v: k for k, v in _STRATEGY2ID.items()}
-
-
-# class LLaVADataset(Dataset):
-
-# def __init__(self,
-# image_folder,
-# image_path_list,
-# per_image_length,
-# data_path=None,
-# tokenizer=None,
-# offline_processed_text_folder=None,
-# max_dataset_length=None,
-# dataset_map_fn=None,
-# template_map_fn=None,
-# max_length=2048,
-# pad_image_to_square=False,
-# sample_num=10240,
-# image_feature_prefix='',
-# identifier='',
-# image_feature_suffix='.pt',
-# unwanted_prefix_csv=None,
-# sample_strategy: str = 'linspace', # 新增:默认等距
-# # ---------- DEBUG 选项 ----------
-# debug_max_samples=None,
-# debug_ratio=None,
-# debug_shuffle=True,
-# debug_seed=3407,
-# debug_include_ids=None):
-# super().__init__()
-
-# # ---- 通过共享内存暴露可变控制量,确保 Hook 在主进程修改后,worker 可见 ----
-# self._sample_num_shm = torch.tensor([int(sample_num)], dtype=torch.int32)
-# self._sample_num_shm.share_memory_()
-# self._pil_shm = torch.tensor([int(per_image_length)], dtype=torch.int32)
-# self._pil_shm.share_memory_()
-# if sample_strategy not in _STRATEGY2ID:
-# raise ValueError(f"Unsupported sample_strategy: {sample_strategy}")
-# self._strategy_shm = torch.tensor([_STRATEGY2ID[sample_strategy]], dtype=torch.int32)
-# self._strategy_shm.share_memory_()
-
-# self.pad_image_to_square = pad_image_to_square
-# self.image_feature_prefix = image_feature_prefix
-# self.identifier = identifier
-
-# # debug opts
-# self._dbg_max = debug_max_samples
-# self._dbg_ratio = debug_ratio
-# self._dbg_shuffle = debug_shuffle
-# self._dbg_seed = int(debug_seed)
-# self._dbg_include_ids = set(debug_include_ids) if debug_include_ids else None
-
-# assert offline_processed_text_folder or (data_path and tokenizer)
-# if offline_processed_text_folder and data_path:
-# print_log(
-# 'Both `offline_processed_text_folder` and `data_path` are set, '
-# 'and we load dataset from `offline_processed_text_folder` '
-# f'({offline_processed_text_folder})',
-# logger='current', level=logging.WARNING)
-
-# # ---------------------- load text ----------------------
-# if offline_processed_text_folder is not None:
-# ds = load_from_disk(offline_processed_text_folder)
-# if isinstance(ds, DatasetDict):
-# ds = ds.get('train', None) or next(iter(ds.values()))
-# assert isinstance(ds, HFDataset)
-# text_ds = ds
-# text_ds = self._apply_debug_subset_to_hf(text_ds)
-# self.text_data = text_ds
-# else:
-# if data_path.endswith('.json'):
-# json_data = json.load(open(data_path))
-# elif data_path.endswith('.jsonl'):
-# json_data = self._load_jsonl(data_path)
-# else:
-# raise NotImplementedError
-
-# # ---- filter out unwanted prefixes (string/list 都兼容)
-# unwanted_prefixes = self._load_unwanted_prefixes(unwanted_prefix_csv)
-# original_count = len(json_data)
-# filtered = []
-# for item in json_data:
-# imgs = item.get('image', [])
-# if isinstance(imgs, str):
-# imgs = [imgs]
-# keep = True
-# for img in imgs:
-# if any(pref in img for pref in unwanted_prefixes):
-# keep = False
-# break
-# if keep:
-# filtered.append(item)
-# json_data = filtered
-# print_log(f'Filtered out {original_count - len(json_data)} samples.', logger='current')
-
-# # ---- debug: include_ids 优先过滤
-# if self._dbg_include_ids:
-# keep = [it for it in json_data if str(it.get('id')) in self._dbg_include_ids]
-# print_log(f'[DEBUG] include_ids -> keep {len(keep)}/{len(json_data)}', logger='current')
-# json_data = keep
-
-# # ---- debug: 子集抽样
-# json_data = self._apply_debug_subset_to_list(json_data)
-
-# # id -> str
-# for idx in range(len(json_data)):
-# if isinstance(json_data[idx].get('id'), int):
-# json_data[idx]['id'] = str(json_data[idx]['id'])
-
-# # HF map & template
-# json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
-# self.text_data = process_hf_dataset(
-# dataset=json_data,
-# tokenizer=tokenizer,
-# max_length=max_length,
-# dataset_map_fn=dataset_map_fn,
-# template_map_fn=template_map_fn,
-# split='train',
-# max_dataset_length=max_dataset_length,
-# remove_unused_columns=False,
-# pack_to_max_length=False,
-# with_image_token=True,
-# per_image_length=self.per_image_length)
-
-# # ---------------------- image feature suffix sanity ----------------------
-# if image_feature_suffix not in ['.csv', '.pt', '.h5']:
-# raise ValueError(
-# f'Unsupported image feature suffix: {image_feature_suffix}. '
-# 'Supported suffixes are: .csv, .pt, .h5')
-# self.image_feature_suffix = image_feature_suffix
-
-# self.image_folder = image_folder
-# self.image_path_list = image_path_list
-
-# # ---------------------- shared-backed properties ----------------------
-# @property
-# def sample_num(self) -> int:
-# return int(self._sample_num_shm.item())
-
-# @sample_num.setter
-# def sample_num(self, v: int):
-# self._sample_num_shm.fill_(int(v))
-
-# @property
-# def per_image_length(self) -> int:
-# return int(self._pil_shm.item())
-
-# @per_image_length.setter
-# def per_image_length(self, v: int):
-# self._pil_shm.fill_(int(v))
-
-# @property
-# def sample_strategy(self) -> str:
-# return _ID2STRATEGY[int(self._strategy_shm.item())]
-
-# @sample_strategy.setter
-# def sample_strategy(self, v: str):
-# if v not in _STRATEGY2ID:
-# raise ValueError(f"Unknown sample_strategy: {v}")
-# self._strategy_shm.fill_(_STRATEGY2ID[v])
-
-# # ---------------------- helpers ----------------------
-# def _load_unwanted_prefixes(self, csv_path):
-# unwanted_prefixes = set()
-# if csv_path and os.path.exists(csv_path):
-# print_log(f'Loading unwanted prefixes from: {csv_path}', logger='current')
-# try:
-# df = pd.read_csv(csv_path)
-# unwanted_prefixes = set(df.iloc[:, 0].astype(str).tolist())
-# print_log(f'Loaded {len(unwanted_prefixes)} prefixes to filter out.', logger='current')
-# except Exception as e:
-# print_log(f'Could not read CSV file {csv_path}. Error: {e}',
-# logger='current', level=logging.ERROR)
-# print_log('Falling back to hardcoded list.', logger='current', level=logging.WARNING)
-
-# if not unwanted_prefixes:
-# print_log('Using hardcoded unwanted prefix list.', logger='current', level=logging.WARNING)
-# unwanted_prefixes = {
-# "TCGA-HT-7476-01Z-00-DX2", "TCGA-44-7661-01Z-00-DX1", "TCGA-DB-A64V-01Z-00-DX1",
-# "TCGA-CS-4938-01Z-00-DX1", "TCGA-DB-5273-01Z-00-DX2", "TCGA-DB-5278-01Z-00-DX1",
-# "TCGA-DB-A4XA-01Z-00-DX1", "TCGA-DB-A4XB-01Z-00-DX1", "TCGA-DB-A4XC-01Z-00-DX2",
-# "TCGA-DU-5849-01Z-00-DX1", "TCGA-DU-6399-01Z-00-DX1", "TCGA-DU-7006-01Z-00-DX1",
-# "TCGA-DU-7013-01Z-00-DX1", "TCGA-DU-8165-01Z-00-DX1", "TCGA-DU-A76O-01Z-00-DX1",
-# "TCGA-DU-A7TG-01Z-00-DX1", "TCGA-E1-A7YM-01Z-00-DX1", "TCGA-E1-A7Z6-01Z-00-DX1",
-# "TCGA-FG-A6J3-01Z-00-DX2", "TCGA-HT-7467-01Z-00-DX2", "TCGA-HT-7468-01Z-00-DX6",
-# "TCGA-HT-7470-01Z-00-DX4", "TCGA-HT-7470-01Z-00-DX9", "TCGA-HT-7473-01Z-00-DX2",
-# "TCGA-HT-7475-01Z-00-DX5", "TCGA-HT-7481-01Z-00-DX1", "TCGA-HT-7482-01Z-00-DX6",
-# "TCGA-HT-7601-01Z-00-DX3", "TCGA-HT-7607-01Z-00-DX10", "TCGA-HT-7608-01Z-00-DX2",
-# "TCGA-HT-7616-01Z-00-DX1", "TCGA-HT-7684-01Z-00-DX2", "TCGA-HT-7689-01Z-00-DX1",
-# "TCGA-HT-7690-01Z-00-DX4", "TCGA-HT-7855-01Z-00-DX1", "TCGA-HT-7856-01Z-00-DX6",
-# "TCGA-HT-7874-01Z-00-DX2", "TCGA-HT-8105-01Z-00-DX1", "TCGA-HT-8108-01Z-00-DX1",
-# "TCGA-HT-A74O-01Z-00-DX1", "TCGA-IK-8125-01Z-00-DX1", "TCGA-P5-A72X-01Z-00-DX1",
-# "TCGA-QH-A65R-01Z-00-DX1", "TCGA-QH-A870-01Z-00-DX1", "TCGA-R8-A6MO-01Z-00-DX7",
-# "TCGA-S9-A6TX-01Z-00-DX1", "TCGA-TM-A84I-01Z-00-DX1", "TCGA-TM-A84L-01Z-00-DX1",
-# "TCGA-TM-A84O-01Z-00-DX1", "TCGA-TQ-A7RP-01Z-00-DX1", "TCGA-VM-A8C8-01Z-00-DX8",
-# "TCGA-VM-A8C9-01Z-00-DX9", "TCGA-VM-A8CA-01Z-00-DX4", "TCGA-VM-A8CB-01Z-00-DX4",
-# "TCGA-VM-A8CB-01Z-00-DX5", "TCGA-VM-A8CD-01Z-00-DX6", "TCGA-VM-A8CE-01Z-00-DX1",
-# "TCGA-VM-A8CE-01Z-00-DX7", "TCGA-QK-A8ZB-01Z-00-DX1"
-# }
-# return unwanted_prefixes
-
-# def _load_jsonl(self, json_file):
-# with open(json_file) as f:
-# return [json.loads(line) for line in f]
-
-# def _apply_debug_subset_to_list(self, items):
-# if not items:
-# return items
-# n_before = len(items)
-# if self._dbg_include_ids:
-# items = [it for it in items if str(it.get('id')) in self._dbg_include_ids]
-# n_before = len(items)
-# print_log(f'[DEBUG] include_ids -> keep {n_before}', logger='current')
-
-# if self._dbg_max is None and self._dbg_ratio is not None:
-# self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
-
-# if self._dbg_max is None:
-# print_log('[DEBUG] dataset full size used.', logger='current')
-# return items
-
-# k = min(int(self._dbg_max), n_before)
-# if k <= 0:
-# return items
-
-# if self._dbg_shuffle:
-# rng = np.random.default_rng(self._dbg_seed)
-# idx = rng.choice(n_before, size=k, replace=False)
-# idx = sorted(idx.tolist())
-# items = [items[i] for i in idx]
-# else:
-# items = items[:k]
-
-# print_log(f'[DEBUG] subset: {len(items)}/{n_before} samples used '
-# f'({"random" if self._dbg_shuffle else "head"}).',
-# logger='current')
-# return items
-
-# def _apply_debug_subset_to_hf(self, ds: HFDataset) -> HFDataset:
-# n_before = ds.num_rows
-# if self._dbg_include_ids:
-# keep_idx = [i for i, ex in enumerate(ds) if str(ex.get('id')) in self._dbg_include_ids]
-# ds = ds.select(keep_idx)
-# print_log(f'[DEBUG] include_ids -> keep {ds.num_rows}/{n_before}', logger='current')
-# n_before = ds.num_rows
-
-# if self._dbg_max is None and self._dbg_ratio is not None:
-# self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
-# if self._dbg_max is None:
-# print_log('[DEBUG] dataset full size used (offline).', logger='current')
-# return ds
-
-# k = min(int(self._dbg_max), n_before)
-# if k <= 0:
-# return ds
-
-# if self._dbg_shuffle:
-# rng = np.random.default_rng(self._dbg_seed)
-# idx = rng.choice(n_before, size=k, replace=False)
-# idx = sorted(idx.tolist())
-# else:
-# idx = list(range(k))
-
-# ds = ds.select(idx)
-# print_log(f'[DEBUG] subset (offline): {ds.num_rows}/{n_before} samples used '
-# f'({"random" if self._dbg_shuffle else "head"}).',
-# logger='current')
-# return ds
-
-# # -------- 每个 worker 的 RNG,保证可复现 --------
-# def _rng(self):
-# """Return a numpy Generator seeded per-worker for reproducibility."""
-# wi = get_worker_info()
-# base = self._dbg_seed
-# if wi is None:
-# seed = (base ^ (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
-# else:
-# seed = (base + wi.id + (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
-# return np.random.default_rng(seed)
-
-# # -------- 路径解析 --------
-# def _parse_stub(self, image_path: str):
-# norm = os.path.normpath(image_path)
-# parts = norm.split(os.sep)
-# if len(parts) < 2:
-# fname = os.path.splitext(parts[-1])[0]
-# tumor = fname.split('-')[0].lower() if '-' in fname else 'unknown'
-# case = fname
-# else:
-# tumor = parts[-2].lower()
-# case = os.path.splitext(parts[-1])[0]
-# return tumor, case
-
-# # -------- 构造特征路径 --------
-# def _build_feature_path(self, tumor_name: str, case_name: str):
-# if self.image_feature_suffix == ".pt":
-# subdir = "pt_files"
-# elif self.image_feature_suffix == ".csv":
-# subdir = "csv_files"
-# elif self.image_feature_suffix == ".h5":
-# subdir = "h5_files"
-# else:
-# raise ValueError(f"Unknown feature suffix: {self.image_feature_suffix}")
-# return os.path.join(
-# self.image_feature_prefix,
-# f"{tumor_name}{self.identifier}",
-# subdir,
-# case_name + self.image_feature_suffix
-# )
-
-# # -------- 选择 patch 索引(支持 linspace / random / random_full) --------
-# def _choose_indices(self, total_rows: int, rng: np.random.Generator):
-# k = self.sample_num
-# if total_rows <= 0:
-# return np.array([], dtype=int)
-
-# strat = self.sample_strategy
-# if strat == "random_full":
-# # 总是返回正好 k 个;不足则有放回
-# replace = total_rows < k
-# idx = rng.choice(total_rows, size=k, replace=replace)
-# return np.sort(idx.astype(int))
-
-# if strat == "random":
-# # 无放回随机;不足则直接用全部(返回 < k 个)
-# if total_rows <= k:
-# return np.arange(total_rows, dtype=int)
-# idx = rng.choice(total_rows, size=k, replace=False)
-# return np.sort(idx.astype(int))
-
-# # 默认:等距 + 抖动;不足则直接全取
-# if total_rows <= k:
-# return np.arange(total_rows, dtype=int)
-# step = total_rows / k
-# jitter = int(rng.integers(0, max(1, int(step))))
-# indices = (np.floor(np.arange(k) * step + jitter)).astype(int)
-# return np.clip(indices, 0, total_rows - 1)
-
-# # ---------------------- rest of class ----------------------
-# @property
-# def modality_length(self):
-# length_list = []
-# for data_dict in self.text_data:
-# cur_len = len(data_dict['input_ids'])
-# image = data_dict.get('image', None)
-# if image is None:
-# cur_len = -cur_len
-# else:
-# n_images = 1 if isinstance(image, str) else len(image)
-# cur_len = cur_len - n_images + self.per_image_length * n_images
-# length_list.append(cur_len)
-# return length_list
-
-# def __len__(self):
-# return len(self.text_data)
-
-# def __getitem__(self, index):
-
-# wi = get_worker_info()
-# if not hasattr(self, "_printed_once"):
-# print_log(
-# f"[LLaVADataset] worker={wi.id if wi else -1} "
-# f"effective_k={self.sample_num} strategy={self.sample_strategy}",
-# logger="current"
-# )
-# self._printed_once = True
-
-# data_dict = self.text_data[index]
-# if data_dict.get('image', None) is None:
-# return data_dict
-
-# image_list = data_dict['image']
-# if isinstance(image_list, str):
-# image_list = [image_list]
-
-# images, coords_list = [], []
-# rng = self._rng()
-
-# for image_file in image_list:
-# tumor_name, case_name = self._parse_stub(image_file)
-# train_image_file = self._build_feature_path(tumor_name, case_name)
-
-# if train_image_file.endswith('.csv'):
-# if not os.path.exists(train_image_file):
-# raise FileNotFoundError(train_image_file)
-# feats_df = pd.read_csv(train_image_file, usecols=range(512), dtype=np.float32)
-# total_rows = len(feats_df)
-# idx = self._choose_indices(total_rows, rng)
-# feats = torch.from_numpy(feats_df.to_numpy()[idx]).float()
-# images.append(feats)
-# coords_list.append(None)
-
-# elif train_image_file.endswith('.pt'):
-# if not os.path.exists(train_image_file):
-# raise FileNotFoundError(train_image_file)
-# feats_np = torch.load(train_image_file, map_location='cpu')
-# if isinstance(feats_np, torch.Tensor):
-# feats_np = feats_np.cpu().numpy()
-# feats_np = feats_np.astype(np.float32, copy=False)
-# total_rows = feats_np.shape[0]
-# idx = self._choose_indices(total_rows, rng)
-# feats = torch.from_numpy(feats_np[idx]).float()
-# images.append(feats)
-# coords_list.append(None)
-
-# elif train_image_file.endswith('.h5'):
-# if not os.path.exists(train_image_file):
-# raise FileNotFoundError(train_image_file)
-# with h5py.File(train_image_file, 'r') as f:
-# feats_np = f['features'][:]
-# coords_np = f['coords'][:]
-# if feats_np.shape[0] != coords_np.shape[0]:
-# raise ValueError(
-# f"Mismatch rows in features ({feats_np.shape[0]}) vs coords ({coords_np.shape[0]}) "
-# f"for {train_image_file}")
-# feats_np = feats_np.astype(np.float32, copy=False)
-# total_rows = feats_np.shape[0]
-# idx = self._choose_indices(total_rows, rng)
-# feats = torch.from_numpy(feats_np[idx]).float()
-# coords = torch.from_numpy(coords_np[idx]).long()
-# images.append(feats)
-# coords_list.append(coords)
-
-# else:
-# raise ValueError(f'Unsupported file: {train_image_file}')
-
-# data_dict['pixel_values'] = images
-# if any(c is not None for c in coords_list):
-# coords_list = [c if c is not None else torch.empty(0, 2, dtype=torch.long)
-# for c in coords_list]
-# data_dict['coords'] = coords_list
-# return data_dict
-
-from __future__ import annotations
-import json, logging, os
-import numpy as np
-import pandas as pd
-import torch, h5py
-
-from datasets import Dataset as HFDataset
-from datasets import DatasetDict, load_from_disk
-from mmengine import print_log
-from torch.utils.data import Dataset, get_worker_info
-
-from xtuner.registry import BUILDER
-from .huggingface import process_hf_dataset
-
-
-class LLaVADataset(Dataset):
- def __init__(self,
- image_folder,
- image_path_list,
- per_image_length,
- data_path=None,
- tokenizer=None,
- offline_processed_text_folder=None,
- max_dataset_length=None,
- dataset_map_fn=None,
- template_map_fn=None,
- max_length=2048,
- pad_image_to_square=False,
- sample_num=10240,
- image_feature_prefix='',
- identifier='',
- image_feature_suffix='.pt',
- unwanted_prefix_csv=None,
- sample_strategy='linspace',
- # ---------- DEBUG ----------
- debug_max_samples=None,
- debug_ratio=None,
- debug_shuffle=True,
- debug_seed=3407,
- debug_include_ids=None):
- super().__init__()
-
- self.sample_num = int(sample_num)
- self.per_image_length = int(per_image_length)
- self.pad_image_to_square = pad_image_to_square
- self.image_feature_prefix = image_feature_prefix
- self.identifier = identifier
- self.sample_strategy = sample_strategy # 'linspace' | 'random' | 'random_full'
-
- # debug opts
- self._dbg_max = debug_max_samples
- self._dbg_ratio = debug_ratio
- self._dbg_shuffle = debug_shuffle
- self._dbg_seed = int(debug_seed)
- self._dbg_include_ids = set(debug_include_ids) if debug_include_ids else None
-
- assert offline_processed_text_folder or (data_path and tokenizer)
- if offline_processed_text_folder and data_path:
- print_log(
- 'Both `offline_processed_text_folder` and `data_path` are set, '
- 'and we load dataset from `offline_processed_text_folder` '
- f'({offline_processed_text_folder})',
- logger='current', level=logging.WARNING)
-
- # ---------------------- load text ----------------------
- if offline_processed_text_folder is not None:
- ds = load_from_disk(offline_processed_text_folder)
- if isinstance(ds, DatasetDict):
- ds = ds.get('train', None) or next(iter(ds.values()))
- assert isinstance(ds, HFDataset)
- text_ds = ds
- text_ds = self._apply_debug_subset_to_hf(text_ds)
- self.text_data = text_ds
- else:
- if data_path.endswith('.json'):
- json_data = json.load(open(data_path))
- elif data_path.endswith('.jsonl'):
- json_data = self._load_jsonl(data_path)
- else:
- raise NotImplementedError
-
- # ---- unwanted prefixes
- unwanted_prefixes = self._load_unwanted_prefixes(unwanted_prefix_csv)
- original_count = len(json_data)
- filtered = []
- for item in json_data:
- imgs = item.get('image', [])
- if isinstance(imgs, str):
- imgs = [imgs]
- keep = True
- for img in imgs:
- if any(pref in img for pref in unwanted_prefixes):
- keep = False
- break
- if keep:
- filtered.append(item)
- json_data = filtered
- print_log(f'Filtered out {original_count - len(json_data)} samples.', logger='current')
-
- # ---- debug include_ids
- if self._dbg_include_ids:
- keep = [it for it in json_data if str(it.get('id')) in self._dbg_include_ids]
- print_log(f'[DEBUG] include_ids -> keep {len(keep)}/{len(json_data)}', logger='current')
- json_data = keep
-
- # ---- debug subset
- json_data = self._apply_debug_subset_to_list(json_data)
-
- # id -> str
- for idx in range(len(json_data)):
- if isinstance(json_data[idx].get('id'), int):
- json_data[idx]['id'] = str(json_data[idx]['id'])
-
- # HF map & template
- json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
- self.text_data = process_hf_dataset(
- dataset=json_data,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=dataset_map_fn,
- template_map_fn=template_map_fn,
- split='train',
- max_dataset_length=max_dataset_length,
- remove_unused_columns=False,
- pack_to_max_length=False,
- with_image_token=True,
- per_image_length=self.per_image_length)
-
- # ---------------------- image feature suffix sanity ----------------------
- if image_feature_suffix not in ['.csv', '.pt', '.h5']:
- raise ValueError(
- f'Unsupported image feature suffix: {image_feature_suffix}. '
- 'Supported suffixes are: .csv, .pt, .h5')
- self.image_feature_suffix = image_feature_suffix
-
- self.image_folder = image_folder
- self.image_path_list = image_path_list
-
- # ---------------------- helpers ----------------------
- def _load_unwanted_prefixes(self, csv_path):
- unwanted_prefixes = set()
- if csv_path and os.path.exists(csv_path):
- print_log(f'Loading unwanted prefixes from: {csv_path}', logger='current')
- try:
- df = pd.read_csv(csv_path)
- unwanted_prefixes = set(df.iloc[:, 0].astype(str).tolist())
- print_log(f'Loaded {len(unwanted_prefixes)} prefixes to filter out.', logger='current')
- except Exception as e:
- print_log(f'Could not read CSV file {csv_path}. Error: {e}',
- logger='current', level=logging.ERROR)
- print_log('Falling back to hardcoded list.', logger='current', level=logging.WARNING)
-
- if not unwanted_prefixes:
- print_log('Using hardcoded unwanted prefix list.', logger='current', level=logging.WARNING)
- unwanted_prefixes = {
- "TCGA-HT-7476-01Z-00-DX2", "TCGA-44-7661-01Z-00-DX1", "TCGA-DB-A64V-01Z-00-DX1",
- "TCGA-CS-4938-01Z-00-DX1", "TCGA-DB-5273-01Z-00-DX2", "TCGA-DB-5278-01Z-00-DX1",
- "TCGA-DB-A4XA-01Z-00-DX1", "TCGA-DB-A4XB-01Z-00-DX1", "TCGA-DB-A4XC-01Z-00-DX2",
- "TCGA-DU-5849-01Z-00-DX1", "TCGA-DU-6399-01Z-00-DX1", "TCGA-DU-7006-01Z-00-DX1",
- "TCGA-DU-7013-01Z-00-DX1", "TCGA-DU-8165-01Z-00-DX1", "TCGA-DU-A76O-01Z-00-DX1",
- "TCGA-DU-A7TG-01Z-00-DX1", "TCGA-E1-A7YM-01Z-00-DX1", "TCGA-E1-A7Z6-01Z-00-DX1",
- "TCGA-FG-A6J3-01Z-00-DX2", "TCGA-HT-7467-01Z-00-DX2", "TCGA-HT-7468-01Z-00-DX6",
- "TCGA-HT-7470-01Z-00-DX4", "TCGA-HT-7470-01Z-00-DX9", "TCGA-HT-7473-01Z-00-DX2",
- "TCGA-HT-7475-01Z-00-DX5", "TCGA-HT-7481-01Z-00-DX1", "TCGA-HT-7482-01Z-00-DX6",
- "TCGA-HT-7601-01Z-00-DX3", "TCGA-HT-7607-01Z-00-DX10", "TCGA-HT-7608-01Z-00-DX2",
- "TCGA-HT-7616-01Z-00-DX1", "TCGA-HT-7684-01Z-00-DX2", "TCGA-HT-7689-01Z-00-DX1",
- "TCGA-HT-7690-01Z-00-DX4", "TCGA-HT-7855-01Z-00-DX1", "TCGA-HT-7856-01Z-00-DX6",
- "TCGA-HT-7874-01Z-00-DX2", "TCGA-HT-8105-01Z-00-DX1", "TCGA-HT-8108-01Z-00-DX1",
- "TCGA-HT-A74O-01Z-00-DX1", "TCGA-IK-8125-01Z-00-DX1", "TCGA-P5-A72X-01Z-00-DX1",
- "TCGA-QH-A65R-01Z-00-DX1", "TCGA-QH-A870-01Z-00-DX1", "TCGA-R8-A6MO-01Z-00-DX7",
- "TCGA-S9-A6TX-01Z-00-DX1", "TCGA-TM-A84I-01Z-00-DX1", "TCGA-TM-A84L-01Z-00-DX1",
- "TCGA-TM-A84O-01Z-00-DX1", "TCGA-TQ-A7RP-01Z-00-DX1", "TCGA-VM-A8C8-01Z-00-DX8",
- "TCGA-VM-A8C9-01Z-00-DX9", "TCGA-VM-A8CA-01Z-00-DX4", "TCGA-VM-A8CB-01Z-00-DX4",
- "TCGA-VM-A8CB-01Z-00-DX5", "TCGA-VM-A8CD-01Z-00-DX6", "TCGA-VM-A8CE-01Z-00-DX1",
- "TCGA-VM-A8CE-01Z-00-DX7", "TCGA-QK-A8ZB-01Z-00-DX1"
- }
- return unwanted_prefixes
-
- def _load_jsonl(self, json_file):
- with open(json_file) as f:
- return [json.loads(line) for line in f]
-
- def _apply_debug_subset_to_list(self, items):
- if not items:
- return items
- n_before = len(items)
- if self._dbg_include_ids:
- items = [it for it in items if str(it.get('id')) in self._dbg_include_ids]
- n_before = len(items)
- print_log(f'[DEBUG] include_ids -> keep {n_before}', logger='current')
-
- if self._dbg_max is None and self._dbg_ratio is not None:
- self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
-
- if self._dbg_max is None:
- print_log('[DEBUG] dataset full size used.', logger='current')
- return items
-
- k = min(int(self._dbg_max), n_before)
- if k <= 0:
- return items
-
- if self._dbg_shuffle:
- rng = np.random.default_rng(self._dbg_seed)
- idx = rng.choice(n_before, size=k, replace=False)
- idx = sorted(idx.tolist())
- items = [items[i] for i in idx]
- else:
- items = items[:k]
-
- print_log(f'[DEBUG] subset: {len(items)}/{n_before} samples used '
- f'({"random" if self._dbg_shuffle else "head"}).',
- logger='current')
- return items
-
- def _apply_debug_subset_to_hf(self, ds: HFDataset) -> HFDataset:
- n_before = ds.num_rows
- if self._dbg_include_ids:
- keep_idx = [i for i, ex in enumerate(ds) if str(ex.get('id')) in self._dbg_include_ids]
- ds = ds.select(keep_idx)
- print_log(f'[DEBUG] include_ids -> keep {ds.num_rows}/{n_before}', logger='current')
- n_before = ds.num_rows
-
- if self._dbg_max is None and self._dbg_ratio is not None:
- self._dbg_max = max(1, int(round(n_before * float(self._dbg_ratio))))
- if self._dbg_max is None:
- print_log('[DEBUG] dataset full size used (offline).', logger='current')
- return ds
-
- k = min(int(self._dbg_max), n_before)
- if k <= 0:
- return ds
-
- if self._dbg_shuffle:
- rng = np.random.default_rng(self._dbg_seed)
- idx = rng.choice(n_before, size=k, replace=False)
- idx = sorted(idx.tolist())
- else:
- idx = list(range(k))
-
- ds = ds.select(idx)
- print_log(f'[DEBUG] subset (offline): {ds.num_rows}/{n_before} samples used '
- f'({"random" if self._dbg_shuffle else "head"}).',
- logger='current')
- return ds
-
- # -------- per-worker RNG --------
- def _rng(self):
- wi = get_worker_info()
- base = self._dbg_seed
- if wi is None:
- seed = (base ^ (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
- else:
- seed = (base + wi.id + (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
- return np.random.default_rng(seed)
-
- # -------- path parsing --------
- def _parse_stub(self, image_path: str):
- norm = os.path.normpath(image_path)
- parts = norm.split(os.sep)
- if len(parts) < 2:
- fname = os.path.splitext(parts[-1])[0]
- tumor = fname.split('-')[0].lower() if '-' in fname else 'unknown'
- case = fname
- else:
- tumor = parts[-2].lower()
- case = os.path.splitext(parts[-1])[0]
- return tumor, case
-
- def _build_feature_path(self, tumor_name: str, case_name: str):
- if self.image_feature_suffix == ".pt":
- subdir = "pt_files"
- elif self.image_feature_suffix == ".csv":
- subdir = "csv_files"
- elif self.image_feature_suffix == ".h5":
- subdir = "h5_files"
- else:
- raise ValueError(f"Unknown feature suffix: {self.image_feature_suffix}")
- return os.path.join(
- self.image_feature_prefix,
- f"{tumor_name}{self.identifier}",
- subdir,
- case_name + self.image_feature_suffix
- )
-
- # -------- choose patch indices --------
- def _choose_indices(self, total_rows: int, rng: np.random.Generator):
- k = self.sample_num
- if total_rows <= 0:
- return np.array([], dtype=int)
-
- if self.sample_strategy == "random_full":
- # Always exactly k rows; with replacement if needed
- replace = total_rows < k
- idx = rng.choice(total_rows, size=k, replace=replace)
- return np.sort(idx.astype(int))
-
- if self.sample_strategy == "random":
- if total_rows <= k:
- return np.arange(total_rows, dtype=int)
- idx = rng.choice(total_rows, size=k, replace=False)
- return np.sort(idx.astype(int))
-
- # linspace
- if total_rows <= k:
- return np.arange(total_rows, dtype=int)
- step = total_rows / k
- jitter = int(rng.integers(0, max(1, int(step))))
- indices = (np.floor(np.arange(k) * step + jitter)).astype(int)
- return np.clip(indices, 0, total_rows - 1)
-
- # ---------------------- rest of class ----------------------
- @property
- def modality_length(self):
- length_list = []
- for data_dict in self.text_data:
- cur_len = len(data_dict['input_ids'])
- image = data_dict.get('image', None)
- if image is None:
- cur_len = -cur_len
- else:
- n_images = 1 if isinstance(image, str) else len(image)
- cur_len = cur_len - n_images + self.per_image_length * n_images
- length_list.append(cur_len)
- return length_list
-
- def __len__(self):
- return len(self.text_data)
-
- def __getitem__(self, index):
- data_dict = self.text_data[index]
- if data_dict.get('image', None) is None:
- return data_dict
-
- image_list = data_dict['image']
- if isinstance(image_list, str):
- image_list = [image_list]
-
- images, coords_list = [], []
- rng = self._rng()
-
- for image_file in image_list:
- tumor_name, case_name = self._parse_stub(image_file)
- train_image_file = self._build_feature_path(tumor_name, case_name)
-
- if train_image_file.endswith('.csv'):
- if not os.path.exists(train_image_file):
- raise FileNotFoundError(train_image_file)
- feats_df = pd.read_csv(train_image_file, usecols=range(512), dtype=np.float32)
- total_rows = len(feats_df)
- idx = self._choose_indices(total_rows, rng)
- feats = torch.from_numpy(feats_df.to_numpy()[idx]).float()
- images.append(feats)
- coords_list.append(None)
-
- elif train_image_file.endswith('.pt'):
- if not os.path.exists(train_image_file):
- raise FileNotFoundError(train_image_file)
- feats_np = torch.load(train_image_file, map_location='cpu')
- if isinstance(feats_np, torch.Tensor):
- feats_np = feats_np.cpu().numpy()
- feats_np = feats_np.astype(np.float32, copy=False)
- total_rows = feats_np.shape[0]
- idx = self._choose_indices(total_rows, rng)
- feats = torch.from_numpy(feats_np[idx]).float()
- images.append(feats)
- coords_list.append(None)
-
- elif train_image_file.endswith('.h5'):
- if not os.path.exists(train_image_file):
- raise FileNotFoundError(train_image_file)
- with h5py.File(train_image_file, 'r') as f:
- feats_np = f['features'][:]
- coords_np = f['coords'][:]
- if feats_np.shape[0] != coords_np.shape[0]:
- raise ValueError(
- f"Mismatch rows in features ({feats_np.shape[0]}) vs coords ({coords_np.shape[0]}) "
- f"for {train_image_file}")
- feats_np = feats_np.astype(np.float32, copy=False)
- total_rows = feats_np.shape[0]
- idx = self._choose_indices(total_rows, rng)
- feats = torch.from_numpy(feats_np[idx]).float()
- coords = torch.from_numpy(coords_np[idx]).long()
- images.append(feats)
- coords_list.append(coords)
-
- else:
- raise ValueError(f'Unsupported file: {train_image_file}')
-
- data_dict['pixel_values'] = images
- if any(c is not None for c in coords_list):
- coords_list = [c if c is not None else torch.empty(0, 2, dtype=torch.long)
- for c in coords_list]
- data_dict['coords'] = coords_list
- return data_dict
\ No newline at end of file
diff --git a/code/xtuner/dataset/map_fns/__init__.py b/code/xtuner/dataset/map_fns/__init__.py
deleted file mode 100644
index 4a488c53eab57eedcd0437c2f239faec445292cf..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .dataset_map_fns import * # noqa: F401, F403
-from .template_map_fn import template_map_fn # noqa: F401
-from .template_map_fn import template_map_fn_factory # noqa: F401
diff --git a/code/xtuner/dataset/map_fns/__pycache__/__init__.cpython-311.pyc b/code/xtuner/dataset/map_fns/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 594962f21ca3a28770a56a021506760776ad6b05..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/__pycache__/template_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/__pycache__/template_map_fn.cpython-311.pyc
deleted file mode 100644
index 5841213ebb72c4b63e9a48314a3fc643bff47468..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/__pycache__/template_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__init__.py b/code/xtuner/dataset/map_fns/dataset_map_fns/__init__.py
deleted file mode 100644
index 449b7b4f20efec582e419fb15f7fcc45f200a585..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/__init__.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .alpaca_map_fn import alpaca_map_fn
-from .alpaca_zh_map_fn import alpaca_zh_map_fn
-from .arxiv_map_fn import arxiv_map_fn
-from .code_alpaca_map_fn import code_alpaca_map_fn
-from .colors_map_fn import colors_map_fn
-from .crime_kg_assitant_map_fn import crime_kg_assitant_map_fn
-from .default_map_fn import default_map_fn
-from .law_reference_map_fn import law_reference_map_fn
-from .llava_map_fn import llava_image_only_map_fn, llava_map_fn
-from .medical_map_fn import medical_map_fn
-from .msagent_map_fn import msagent_react_map_fn
-from .oasst1_map_fn import oasst1_map_fn
-from .openai_map_fn import openai_map_fn
-from .openorca_map_fn import openorca_map_fn
-from .pretrain_map_fn import pretrain_map_fn
-from .sql_map_fn import sql_map_fn
-from .stack_exchange_map_fn import stack_exchange_map_fn
-from .tiny_codes_map_fn import tiny_codes_map_fn
-from .wizardlm_map_fn import wizardlm_map_fn
-
-DATASET_FORMAT_MAPPING = dict(
- alpaca=alpaca_map_fn,
- alpaca_zh=alpaca_zh_map_fn,
- arxiv=arxiv_map_fn,
- code_alpaca=code_alpaca_map_fn,
- colors=colors_map_fn,
- crime_kg_assitan=crime_kg_assitant_map_fn,
- default=default_map_fn,
- law_reference=law_reference_map_fn,
- llava_image_only=llava_image_only_map_fn,
- llava=llava_map_fn,
- medical=medical_map_fn,
- msagent_react=msagent_react_map_fn,
- oasst1=oasst1_map_fn,
- openai=openai_map_fn,
- openorca=openorca_map_fn,
- pretrain=pretrain_map_fn,
- sql=sql_map_fn,
- stack_exchange=stack_exchange_map_fn,
- tiny_codes=tiny_codes_map_fn,
- wizardlm=wizardlm_map_fn,
-)
-
-__all__ = [
- 'alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn', 'arxiv_map_fn',
- 'medical_map_fn', 'openorca_map_fn', 'code_alpaca_map_fn',
- 'tiny_codes_map_fn', 'colors_map_fn', 'law_reference_map_fn',
- 'crime_kg_assitant_map_fn', 'sql_map_fn', 'openai_map_fn',
- 'wizardlm_map_fn', 'stack_exchange_map_fn', 'msagent_react_map_fn',
- 'pretrain_map_fn', 'default_map_fn', 'llava_image_only_map_fn',
- 'llava_map_fn', 'DATASET_FORMAT_MAPPING'
-]
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/__init__.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index ebaed7f750a40b16173fd47c1c713fed6b379faa..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/alpaca_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/alpaca_map_fn.cpython-311.pyc
deleted file mode 100644
index bccd7e7f901684c9ae26cba56f9a52f45b95db7b..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/alpaca_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/alpaca_zh_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/alpaca_zh_map_fn.cpython-311.pyc
deleted file mode 100644
index dfea0581e62115db33b878f6984a16ada154bd2a..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/alpaca_zh_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/arxiv_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/arxiv_map_fn.cpython-311.pyc
deleted file mode 100644
index 20aa34fcf2b34dbd87aa0d3d198d1c3637e48b59..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/arxiv_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/code_alpaca_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/code_alpaca_map_fn.cpython-311.pyc
deleted file mode 100644
index 8f9f466c9a348482108679ba5d8999a954f96952..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/code_alpaca_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/colors_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/colors_map_fn.cpython-311.pyc
deleted file mode 100644
index 89df6c650fe11de1316deaf1abc7e584636d3957..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/colors_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/crime_kg_assitant_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/crime_kg_assitant_map_fn.cpython-311.pyc
deleted file mode 100644
index 10933ba266b0e3ddc6c034d8d11f8196e825b846..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/crime_kg_assitant_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/default_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/default_map_fn.cpython-311.pyc
deleted file mode 100644
index 4d278285a0011b9766a5560167f2018fa4531df9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/default_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/law_reference_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/law_reference_map_fn.cpython-311.pyc
deleted file mode 100644
index 0b0d53b2c5b38e57da332a5beb2f897da9f81526..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/law_reference_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/llava_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/llava_map_fn.cpython-311.pyc
deleted file mode 100644
index b72076618fe639ad0732c759b7fa8d1fe89d475d..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/llava_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/medical_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/medical_map_fn.cpython-311.pyc
deleted file mode 100644
index 4971400ededb296a793232b82e75ea84577cc22d..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/medical_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/msagent_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/msagent_map_fn.cpython-311.pyc
deleted file mode 100644
index c167a8e29106bb1558c406b66987f29a06c38d5a..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/msagent_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/oasst1_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/oasst1_map_fn.cpython-311.pyc
deleted file mode 100644
index f9bf83a91c7cfdfeb5547e857989453379bde731..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/oasst1_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/openai_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/openai_map_fn.cpython-311.pyc
deleted file mode 100644
index e9c6bcc74a5c14054c4e810ad782022d6c90f245..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/openai_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/openorca_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/openorca_map_fn.cpython-311.pyc
deleted file mode 100644
index 9dd50a84832cba26953de6724af0cfb0ea6974df..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/openorca_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/pretrain_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/pretrain_map_fn.cpython-311.pyc
deleted file mode 100644
index c36114b4fd19956250e42fc232ccdb35c59e1ec7..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/pretrain_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/sql_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/sql_map_fn.cpython-311.pyc
deleted file mode 100644
index 66b23cf719f758360e58f959b3903f419afb93e9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/sql_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/stack_exchange_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/stack_exchange_map_fn.cpython-311.pyc
deleted file mode 100644
index f50e189b861182e7723d37c1289f415e92308564..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/stack_exchange_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/tiny_codes_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/tiny_codes_map_fn.cpython-311.pyc
deleted file mode 100644
index 89bb50d1fcb0611cc79213cd945db1000b25f808..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/tiny_codes_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/wizardlm_map_fn.cpython-311.pyc b/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/wizardlm_map_fn.cpython-311.pyc
deleted file mode 100644
index 69f2b9b9014dcbe24f4fd112ca1336c5f69454c4..0000000000000000000000000000000000000000
Binary files a/code/xtuner/dataset/map_fns/dataset_map_fns/__pycache__/wizardlm_map_fn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/alpaca_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/alpaca_map_fn.py
deleted file mode 100644
index d64ac3a1cb6f2d5ee5c84b2f5cb08f84d5001ac5..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/alpaca_map_fn.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-
-def alpaca_map_fn(example):
- if example.get('output') == '':
- return {'conversation': []}
- else:
- return {
- 'conversation': [{
- 'input': f"{example['instruction']}\n{example['input']}",
- 'output': example['output']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/alpaca_zh_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/alpaca_zh_map_fn.py
deleted file mode 100644
index 5e17cfa048325af7feadc1fd0452481d65b64cd8..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/alpaca_zh_map_fn.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-
-def alpaca_zh_map_fn(example):
- return {
- 'conversation': [{
- 'input': f"{example['instruction_zh']}\n{example['input_zh']}",
- 'output': example['output_zh']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/arxiv_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/arxiv_map_fn.py
deleted file mode 100644
index 52bcc4e341708d51d474a3d9db6dcc2ad65df454..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/arxiv_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def arxiv_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.arxiv_gentile,
- 'input': example['abstract'],
- 'output': example['title']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/code_alpaca_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/code_alpaca_map_fn.py
deleted file mode 100644
index ece86ff209807d6e8a555eef95a3205d62aa5144..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/code_alpaca_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def code_alpaca_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.coder,
- 'input': example['prompt'],
- 'output': example['completion']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/colors_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/colors_map_fn.py
deleted file mode 100644
index 17d08bf207cc02d74c2833f1d24da7962e4cd629..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/colors_map_fn.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def colors_map_fn(example):
- desc = ':'.join(example['description'].split(':')[1:]).strip()
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.colorist,
- 'input': desc,
- 'output': example['color']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/crime_kg_assitant_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/crime_kg_assitant_map_fn.py
deleted file mode 100644
index b7511a98d94d53aea340a216d9f323c9ae166a41..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/crime_kg_assitant_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def crime_kg_assitant_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.lawyer,
- 'input': example['input'],
- 'output': example['output']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/default_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/default_map_fn.py
deleted file mode 100644
index 0424b884839cd20168ef9c8d26e4363eb8850503..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/default_map_fn.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def default_map_fn(example):
- return {
- 'conversation': [{
- 'input': example['input'],
- 'output': example['output']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/law_reference_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/law_reference_map_fn.py
deleted file mode 100644
index 297086fa082c9c045e6f67af4d74568029b4ffd6..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/law_reference_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def law_reference_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.lawyer,
- 'input': example['question'],
- 'output': example['answer']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/llava_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/llava_map_fn.py
deleted file mode 100644
index a08ca395b6c4fd208a944d97e98e94fa235c15e4..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/llava_map_fn.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-
-
-def llava_image_only_map_fn(example):
- # input contains the DEFAULT_IMAGE_TOKEN only
- messages = example['conversations']
- input = ''
- conversation = []
- while messages and messages[0]['from'] == 'gpt':
- # Skip the first one if it is from gpt
- messages = messages[1:]
- for msg in messages:
- if msg['from'] == 'human':
- assert DEFAULT_IMAGE_TOKEN in msg['value']
- input += DEFAULT_IMAGE_TOKEN
- elif msg['from'] == 'gpt':
- conversation.append({'input': input, 'output': msg['value']})
- input = ''
- else:
- raise NotImplementedError
- return {'conversation': conversation}
-
-
-def llava_map_fn(example):
- messages = example['conversations']
- input = ''
- conversation = []
- while messages and messages[0]['from'] == 'gpt':
- # Skip the first one if it is from gpt
- messages = messages[1:]
- for msg in messages:
- if msg['from'] == 'human':
- if DEFAULT_IMAGE_TOKEN in msg['value']:
- msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
- '').strip()
- msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
- msg['value'] = msg['value'].strip()
- input += msg['value']
-
- elif msg['from'] == 'gpt':
- conversation.append({'input': input, 'output': msg['value']})
- input = ''
- else:
- raise NotImplementedError
- return {'conversation': conversation}
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/medical_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/medical_map_fn.py
deleted file mode 100644
index 60a955454bee80e283ac950ef561e642affc6eef..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/medical_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def medical_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.medical,
- 'input': '{instruction}\n{input}'.format(**example),
- 'output': example['output']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/msagent_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/msagent_map_fn.py
deleted file mode 100644
index fef8b1c5c680b58bf4a6817a6881b1adb021b3f4..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/msagent_map_fn.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import json
-import re
-
-think_regex = r'(.*?)(<\|startofthink\|\>)(.*?)(<\|endofthink\|\>)'
-exec_regex = r'(<\|startofexec\|\>)(.*?)(<\|endofexec\|\>)(.*?)$'
-
-
-def replace_think(match):
- out_text = ''
- if match.group(1).strip() != '':
- out_text += f'Thought:{match.group(1).strip()}\n'
- think_text = match.group(3).replace('```JSON',
- '').replace('```',
- '').replace('\n', '')
- think_json = json.loads(think_text)
- out_text += (f"Action:{think_json['api_name']}\n"
- f"Action Input:{think_json['parameters']}\n")
- return out_text
-
-
-def replace_exec(match):
- out_text = ''
- exec_text = match.group(2).replace('```JSON',
- '').replace('```',
- '').replace('\n', '')
- exec_json = json.loads(exec_text)
- out_text += f'Response:{exec_json}\n'
- if match.group(4).strip() != '':
- out_text += f'Final Answer:{match.group(4).strip()}\n'
- return out_text
-
-
-def extract_json_objects(text, decoder=json.JSONDecoder()):
- pos = 0
- results = []
- while True:
- match = text.find('{', pos)
- if match == -1:
- break
- try:
- result, index = decoder.raw_decode(text[match:])
- if 'name' in result and 'description' in result:
- results.append(result)
- pos = match + index
- else:
- pos = match + 1
- except ValueError:
- pos = match + 1
- return results
-
-
-def msagent_react_map_fn(example):
- text = example['conversations']
- if isinstance(text, str):
- text = eval(text)
- if len(text) < 2: # Filter out invalid data
- return {'conversation': []}
- conversation = []
- system_text = ''
- input_text = ''
- for t in text:
- if t['from'] == 'system':
- system_text += '你是一个可以调用外部工具的助手,可以使用的工具包括:\n'
- json_objects = extract_json_objects(t['value'])
- api_dict = {}
- for obj in json_objects:
- api_dict[obj['name']] = obj['description']
- try:
- params = {
- i['name']: i['description']
- for i in obj['paths'][0]['parameters']
- }
- api_dict[obj['name']] += f'\n输入参数: {params}'
- except Exception:
- pass
- system_text += f'{api_dict}\n'
- system_text += (
- '如果使用工具请遵循以下格式回复:\n```\n'
- 'Thought:思考你当前步骤需要解决什么问题,是否需要使用工具\n'
- f'Action:工具名称,你的工具必须从 [{str(list(api_dict.keys()))}] 选择\n'
- 'Action Input:工具输入参数\n```\n工具返回按照以下格式回复:\n```\n'
- 'Response:调用工具后的结果\n```\n如果你已经知道了答案,或者你不需要工具,'
- '请遵循以下格式回复\n```\n'
- 'Thought:给出最终答案的思考过程\n'
- 'Final Answer:最终答案\n```\n开始!\n')
- elif t['from'] == 'user':
- input_text += f"{t['value']}\n"
- elif t['from'] == 'assistant':
- output = t['value']
- output_response = None
- try:
- if '<|startofexec|>' in output:
- output, output_response = output.split('<|startofexec|>')
- output_response = '<|startofexec|>' + output_response
- output, think_cnt = re.subn(
- think_regex, replace_think, output, flags=re.DOTALL)
- except Exception:
- return {'conversation': []}
-
- if think_cnt == 0:
- output = f'Final Answer:{output}\n'
- else:
- output = f'{output}\n'
- conversation.append({
- 'system': system_text,
- 'input': input_text,
- 'output': output
- })
- system_text = ''
- input_text = ''
- if output_response is not None:
- try:
- output_response, exec_cnt = re.subn(
- exec_regex,
- replace_exec,
- output_response,
- flags=re.DOTALL)
- if 'Final Answer:' in output_response:
- output_response, output_answer = output_response.split(
- 'Final Answer:')
- output_answer = 'Final Answer:' + output_answer
- conversation.append({
- 'system': output_response,
- 'output': output_answer
- })
- except Exception:
- pass
- return {'conversation': conversation}
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/oasst1_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/oasst1_map_fn.py
deleted file mode 100644
index e1e13a01525c8beacc03cc27bb36745dbe63da58..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/oasst1_map_fn.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def oasst1_map_fn(example):
- r"""Example before preprocessing:
- example['text'] = '### Human: Can you explain xxx'
- '### Assistant: Sure! xxx'
- '### Human: I didn't understand how xxx'
- '### Assistant: It has to do with a process xxx.'
-
- Example after preprocessing:
- example['conversation'] = [
- {
- 'input': 'Can you explain xxx',
- 'output': 'Sure! xxx'
- },
- {
- 'input': 'I didn't understand how xxx',
- 'output': 'It has to do with a process xxx.'
- }
- ]
- """
- data = []
- for sentence in example['text'].strip().split('###'):
- sentence = sentence.strip()
- if sentence[:6] == 'Human:':
- data.append(sentence[6:].strip())
- elif sentence[:10] == 'Assistant:':
- data.append(sentence[10:].strip())
- if len(data) % 2:
- # The last round of conversation solely consists of input
- # without any output.
- # Discard the input part of the last round, as this part is ignored in
- # the loss calculation.
- data.pop()
- conversation = []
- for i in range(0, len(data), 2):
- single_turn_conversation = {'input': data[i], 'output': data[i + 1]}
- conversation.append(single_turn_conversation)
- return {'conversation': conversation}
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/openai_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/openai_map_fn.py
deleted file mode 100644
index 468e738f707e0ecae75e89e6a18b91f39b466d56..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/openai_map_fn.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def openai_map_fn(example):
- """
- Example before preprocessing:
- example["messages"] = [
- { "role": "system", "content": "You are an assistant that
- occasionally misspells words." },
- { "role": "user", "content": "Tell me a story." },
- { "role": "assistant", "content": "One day a student
- went to schoool." }
- ]
- Example after preprocessing:
- example["conversation"] = [
- {
- "system": "You are an assistant that occasionally misspells
- words.",
- "input": "Tell me a story.",
- "output": "One day a student went to schoool."
- }
- ]
- """
- messages = example['messages']
- system = ''
- input = ''
- conversation = []
- while messages and messages[0]['role'] == 'assistant':
- # Skip the first one if it is from assistant
- messages = messages[1:]
- for msg in messages:
- if msg['role'] == 'system':
- system = msg['content']
- elif msg['role'] == 'user':
- input += msg['content']
- elif msg['role'] == 'assistant':
- output_with_loss = msg.get('loss', 'True')
- output_with_loss = str(output_with_loss)
- output_with_loss = output_with_loss.lower() == 'true'
- conversation.append({
- 'system': system,
- 'input': input,
- 'output': msg['content'],
- 'output_with_loss': output_with_loss
- })
- system = ''
- input = ''
- else:
- raise NotImplementedError
- return {'conversation': conversation}
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/openorca_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/openorca_map_fn.py
deleted file mode 100644
index 45e58f3b9dd8e495c27050573eac4271eb7c746c..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/openorca_map_fn.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def openorca_map_fn(example):
- return {
- 'conversation': [{
- 'system': example['system_prompt'],
- 'input': example['question'],
- 'output': example['response']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/pretrain_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/pretrain_map_fn.py
deleted file mode 100644
index 861302ba8690074210ae8a751ba423075d10a240..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/pretrain_map_fn.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def pretrain_map_fn(example):
- r"""Example before preprocessing:
- example['text'] = 'xxx'
-
- Example after preprocessing:
- example['conversation'] = [
- {
- 'input': '',
- 'output': 'xxx'
- },
- ]
- """
- return {
- 'conversation': [{
- 'input': '',
- 'output': example['text'].strip(),
- 'need_eos_token': False
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/sql_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/sql_map_fn.py
deleted file mode 100644
index c83434f8de496a5a15f18c3038771070b0e4b608..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/sql_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def sql_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.sql,
- 'input': '{context}\n{question}'.format(**example),
- 'output': example['answer']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/stack_exchange_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/stack_exchange_map_fn.py
deleted file mode 100644
index 9fc3520e2919283133afb7ec26ff009469f38475..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/stack_exchange_map_fn.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def stack_exchange_map_fn(example):
- return {
- 'conversation': [{
- 'input': example['question'],
- 'output': example['response']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/tiny_codes_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/tiny_codes_map_fn.py
deleted file mode 100644
index fe0cc02b48c33ab3d9a0e717c293399f74cd6cfa..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/tiny_codes_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.utils import SYSTEM_TEMPLATE
-
-
-def tiny_codes_map_fn(example):
- return {
- 'conversation': [{
- 'system': SYSTEM_TEMPLATE.coder,
- 'input': example['prompt'],
- 'output': example['response']
- }]
- }
diff --git a/code/xtuner/dataset/map_fns/dataset_map_fns/wizardlm_map_fn.py b/code/xtuner/dataset/map_fns/dataset_map_fns/wizardlm_map_fn.py
deleted file mode 100644
index 0174760d006b3efe2240671da672e2367076d30b..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/dataset_map_fns/wizardlm_map_fn.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def wizardlm_map_fn(example):
- messages = example['conversations']
- input = ''
- conversation = []
- while messages and messages[0]['from'] == 'gpt':
- # Skip the first one if it is from gpt
- messages = messages[1:]
- for msg in messages:
- if msg['from'] == 'human':
- input += msg['value']
- elif msg['from'] == 'gpt':
- conversation.append({'input': input, 'output': msg['value']})
- input = ''
- else:
- raise NotImplementedError
- return {'conversation': conversation}
diff --git a/code/xtuner/dataset/map_fns/template_map_fn.py b/code/xtuner/dataset/map_fns/template_map_fn.py
deleted file mode 100644
index d7673b99efcdc2e1215303755401d68f570eedf2..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/map_fns/template_map_fn.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
-
-from mmengine.utils.misc import get_object_from_string
-
-
-def template_map_fn(example, template):
- conversation = example.get('conversation', [])
- for i, single_turn_conversation in enumerate(conversation):
- input = single_turn_conversation.get('input', '')
- if input is None:
- input = ''
- input_text = template.INSTRUCTION.format(input=input, round=i + 1)
- system = single_turn_conversation.get('system', '')
- if system != '' and system is not None:
- system = template.SYSTEM.format(system=system)
- input_text = system + input_text
- single_turn_conversation['input'] = input_text
-
- if template.get('SUFFIX', None):
- output_text = single_turn_conversation.get('output', '')
- output_text += template.SUFFIX
- single_turn_conversation['output'] = output_text
-
- # SUFFIX_AS_EOS is False ==> need_eos_token is True
- single_turn_conversation['need_eos_token'] = \
- not template.get('SUFFIX_AS_EOS', False)
- single_turn_conversation['sep'] = template.get('SEP', '')
-
- return {'conversation': conversation}
-
-
-def template_map_fn_factory(template):
- if isinstance(template, str): # for resume
- template = get_object_from_string(template)
- return partial(template_map_fn, template=template)
diff --git a/code/xtuner/dataset/modelscope.py b/code/xtuner/dataset/modelscope.py
deleted file mode 100644
index 9400050c34553dc8087a0f78e62918e47835d349..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/modelscope.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.config import Config, ConfigDict
-
-from xtuner.registry import BUILDER
-from .huggingface import process_hf_dataset
-
-
-def process_ms_dataset(dataset, split='train', *args, **kwargs):
- """Post-process the dataset loaded from the ModelScope Hub."""
-
- if isinstance(dataset, (Config, ConfigDict)):
- dataset = BUILDER.build(dataset)
- if isinstance(dataset, dict):
- dataset = dataset[split]
- dataset = dataset.to_hf_dataset()
- return process_hf_dataset(dataset, *args, **kwargs)
diff --git a/code/xtuner/dataset/moss_sft.py b/code/xtuner/dataset/moss_sft.py
deleted file mode 100644
index a5b7122bb700847dcab584e93b3ecc44c37404d3..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/moss_sft.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import json
-import os
-
-import torch
-from mmengine.config import Config, ConfigDict
-from mmengine.logging import print_log
-from torch.utils.data import Dataset
-from tqdm import tqdm
-
-from xtuner.registry import BUILDER
-
-
-class MOSSSFTDataset(Dataset):
-
- def __init__(self, data_file, tokenizer, max_length=2048, bot_name=None):
- super().__init__()
- self.bot_name = bot_name
- self.src_data_file = data_file
- if isinstance(tokenizer, dict) or isinstance(
- tokenizer, Config) or isinstance(tokenizer, ConfigDict):
- self.tokenizer = BUILDER.build(tokenizer)
- else:
- self.tokenizer = tokenizer
- self.max_length = max_length
-
- self.data = []
- # We do not calculate losses for the meta instruction or results
- # returned by plugins
- # The token spans with label -100, [(span_start, span_end), ...]
- self.no_loss_spans = []
- self.labels = []
-
- self.pre = len(
- self.tokenizer.encode('<|Results|>:', add_special_tokens=False))
- self.post = len(
- self.tokenizer.encode('\n', add_special_tokens=False))
-
- self.load_data()
- self.process_data()
-
- def load_data(self):
- print_log('Loading MOSS SFT data...', 'current')
- name = f'{self.tokenizer.__class__.__name__}_{self.bot_name}'
- data_file = self.src_data_file.replace('.jsonl', f'_data_{name}')
- no_loss_spans_file = self.src_data_file.replace(
- '.jsonl', f'_no_loss_spans_{name}')
- if os.path.exists(data_file) and os.path.exists(no_loss_spans_file):
- self.data = torch.load(data_file, map_location='cpu')
- self.no_loss_spans = torch.load(
- no_loss_spans_file, map_location='cpu')
- else:
- with open(self.src_data_file) as f:
- for line in tqdm(f):
- sample = json.loads(line)
-
- chat = sample['chat']
- num_turns = int(sample['num_turns'])
-
- meta_instruction = sample['meta_instruction']
- if self.bot_name is not None:
- meta_instruction = meta_instruction.replace(
- 'MOSS', self.bot_name)
- instruction_ids = self.tokenizer.encode(meta_instruction)
- assert isinstance(instruction_ids,
- list) and len(instruction_ids) > 0
-
- input_ids = copy.deepcopy(instruction_ids)
- no_loss_spans = [(0, len(instruction_ids))]
- try:
- for i in range(num_turns):
- cur_turn_ids = []
- cur_no_loss_spans = []
- cur_turn = chat[f'turn_{i+1}']
- for key, value in cur_turn.items():
- if self.bot_name is not None:
- value = value.replace(
- 'MOSS', self.bot_name)
- cur_ids = self.tokenizer.encode(
- value, add_special_tokens=False)
- if key == 'Tool Responses':
- # The format tokens
- # (<|Results|>:...\n)
- # should have losses.
- cur_no_loss_spans.append(
- (len(input_ids + cur_turn_ids) +
- self.pre,
- len(input_ids + cur_turn_ids +
- cur_ids) - self.post))
-
- assert isinstance(cur_ids,
- list) and len(cur_ids) > 0
-
- cur_turn_ids.extend(cur_ids)
-
- if len(input_ids + cur_turn_ids) > self.max_length:
- break
-
- input_ids.extend(cur_turn_ids)
- no_loss_spans.extend(cur_no_loss_spans)
- if len(input_ids) == len(instruction_ids):
- continue
-
- assert len(input_ids) > 0 and len(
- input_ids) <= self.max_length
-
- self.data.append(input_ids)
- self.no_loss_spans.append(no_loss_spans)
- except Exception:
- pass
- torch.save(self.data, data_file)
- torch.save(self.no_loss_spans, no_loss_spans_file)
- print_log(
- f'Load data successfully, total {len(self.data)} training samples',
- 'current')
-
- def process_data(self):
- for item, no_loss in zip(self.data, self.no_loss_spans):
- label = copy.deepcopy(item)
- for loc in no_loss:
- label[loc[0]:loc[1]] = [-100] * (loc[1] - loc[0])
- self.labels.append(label)
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, index):
- return {'input_ids': self.data[index], 'labels': self.labels[index]}
diff --git a/code/xtuner/dataset/preference_dataset.py b/code/xtuner/dataset/preference_dataset.py
deleted file mode 100644
index 371ef829039742762ec7c725fb3a1acd4a57b420..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/preference_dataset.py
+++ /dev/null
@@ -1,386 +0,0 @@
-import copy
-import json
-import os
-from datetime import timedelta
-from functools import partial
-from multiprocessing import Process, Queue
-from typing import Callable, Dict, List
-
-import numpy as np
-import torch.distributed as dist
-import tqdm
-from datasets import Dataset as HFDataset
-from datasets import concatenate_datasets
-from mmengine.config import Config, ConfigDict
-from mmengine.logging import print_log
-from mmengine.utils.misc import get_object_from_string
-from torch.utils.data import Dataset
-from transformers import AutoTokenizer
-
-from xtuner.registry import BUILDER, MAP_FUNC
-from .huggingface import build_origin_dataset
-
-
-def _worker(
- tokenize_fun: Callable,
- data_queue: Queue,
- out_queue: Queue,
-):
- while True:
- data_chunk = data_queue.get()
-
- if data_chunk is None:
- out_queue.put(None)
- break
- chunk_results = []
- for idx, data in data_chunk:
- chunk_results.append([idx, tokenize_fun(data)])
- out_queue.put(chunk_results)
-
-
-def _chunk_data_to_queue(data_queue: Queue, data: List[Dict], chunk_size: int,
- nproc):
- data_iter = iter(data)
- chunk_data = []
- while True:
- try:
- item = next(data_iter)
- except StopIteration:
- break
- chunk_data.append(item)
- if len(chunk_data) == chunk_size:
- data_queue.put(chunk_data)
- chunk_data = []
- if chunk_data:
- data_queue.put(chunk_data)
-
- for _ in range(nproc):
- data_queue.put(None)
-
-
-def _multi_progress(tokenize_fun_p, dataset, nproc, task_num, chunksize,
- description):
- processes = []
- data_queue = Queue()
- output_queue = Queue()
- bar = tqdm.tqdm(total=task_num, desc=description)
- # task_id = bar.add_task(total=task_num, description=description)
- dataset = enumerate(dataset)
- _chunk_data_to_queue(data_queue, dataset, chunksize, nproc)
- for _ in range(nproc):
- process = Process(
- target=_worker, args=(tokenize_fun_p, data_queue, output_queue))
- process.start()
- processes.append(process)
-
- results = []
- finished_process = 0
- while finished_process < nproc:
- chunk_results = output_queue.get()
- if chunk_results is None:
- finished_process += 1
- continue
- results.extend(chunk_results)
- bar.update(len(chunk_results))
- bar.refresh()
- results = map(lambda x: x[1], sorted(results, key=lambda x: x[0]))
- return results
-
-
-def load_jsonl_dataset(data_files=None, data_dir=None, suffix=None):
- assert (data_files is not None) != (data_dir is not None)
- if data_dir is not None:
- data_files = os.listdir(data_dir)
- data_files = [os.path.join(data_dir, fn) for fn in data_files]
- if suffix is not None:
- data_files = [fp for fp in data_files if fp.endswith(suffix)]
- elif isinstance(data_files, str):
- data_files = [data_files]
-
- dataset_list = []
- for fp in data_files:
- with open(fp, encoding='utf-8') as file:
- data = [json.loads(line) for line in file]
- ds = HFDataset.from_list(data)
- dataset_list.append(ds)
- dataset = concatenate_datasets(dataset_list)
- return dataset
-
-
-def tokenize(pair: str,
- tokenizer: AutoTokenizer,
- max_length: int,
- is_reward: bool = False,
- reward_token_id: int = -1):
- prompt = tokenizer.apply_chat_template(
- pair['prompt'], tokenize=False, add_generation_prompt=True)
- chosen = tokenizer.apply_chat_template(
- pair['prompt'] + pair['chosen'],
- tokenize=False,
- add_generation_prompt=False)
- rejected = tokenizer.apply_chat_template(
- pair['prompt'] + pair['rejected'],
- tokenize=False,
- add_generation_prompt=False)
- prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
- chosen_ids = tokenizer.encode(chosen, add_special_tokens=False)
- rejected_ids = tokenizer.encode(rejected, add_special_tokens=False)
-
- if len(chosen_ids) > max_length:
- chosen_ids = chosen_ids[:max_length]
- if len(rejected_ids) > max_length:
- rejected_ids = rejected_ids[:max_length]
-
- if is_reward:
- # reward label
- chosen_ids = chosen_ids + [reward_token_id]
- rejected_ids = rejected_ids + [reward_token_id]
- chosen_labels = [-100] * len(chosen_ids[:-1]) + [0]
- rejected_labels = [-100] * len(rejected_ids[:-1]) + [1]
- else:
- # dpo label
- prompt_len = min(len(prompt_ids), max_length)
- chosen_labels = [-100] * prompt_len + copy.deepcopy(
- chosen_ids[prompt_len:])
- rejected_labels = [-100] * prompt_len + copy.deepcopy(
- rejected_ids[prompt_len:])
-
- return {
- 'chosen_ids': chosen_ids,
- 'rejected_ids': rejected_ids,
- 'chosen_labels': chosen_labels,
- 'rejected_labels': rejected_labels,
- }
-
-
-class PreferenceDataset(Dataset):
-
- def __init__(
- self,
- dataset: HFDataset,
- tokenizer: AutoTokenizer,
- max_length: int,
- is_dpo: bool = True,
- is_reward: bool = False,
- reward_token_id: int = -1,
- num_proc: int = 32,
- ) -> None:
- self.max_length = max_length
- assert is_dpo != is_reward, \
- 'Only one of is_dpo and is_reward can be True'
- if is_reward:
- assert reward_token_id != -1, \
- 'reward_token_id should be set if is_reward is True'
-
- self.is_dpo = is_dpo
- self.is_reward = is_reward
- self.reward_token_id = reward_token_id
- self.tokenized_pairs = []
-
- for tokenized_pair in _multi_progress(
- partial(
- tokenize,
- tokenizer=tokenizer,
- max_length=max_length,
- is_reward=is_reward,
- reward_token_id=reward_token_id),
- dataset,
- nproc=num_proc,
- task_num=len(dataset),
- chunksize=num_proc,
- description='Tokenizing dataset'):
- self.tokenized_pairs.append(tokenized_pair)
-
- def __len__(self):
- return len(self.tokenized_pairs)
-
- def __getitem__(self, idx):
- return self.tokenized_pairs[idx]
-
-
-class PackedDatasetWrapper(Dataset):
-
- def __init__(self,
- dataset,
- max_packed_length=16384,
- shuffle_before_pack=True) -> None:
- super().__init__()
- self.max_packed_length = max_packed_length
- self.lengths = []
- self.data = []
-
- indices = np.arange(len(dataset))
- if shuffle_before_pack:
- np.random.shuffle(indices)
-
- data_bin = []
- bin_seq_len = 0
- removed = 0
- for idx in indices:
- data = dataset[int(idx)]
- cur_len = len(data['chosen_ids']) + len(data['rejected_ids'])
- if cur_len > max_packed_length:
- print_log(
- f'sequence length {cur_len} is '
- f'larger than max_packed_length {max_packed_length}',
- logger='current')
- removed += 1
- continue
- if (bin_seq_len +
- cur_len) > max_packed_length and len(data_bin) > 0:
- self.data.append(data_bin)
- self.lengths.append(bin_seq_len)
- data_bin = []
- bin_seq_len = 0
- data_bin.append(data)
- bin_seq_len += cur_len
-
- if len(data_bin) > 0:
- self.data.append(data_bin)
- self.lengths.append(bin_seq_len)
- if removed > 0:
- print_log(
- f'removed {removed} samples because '
- f'of length larger than {max_packed_length}',
- logger='current')
- print_log(
- f'The batch numbers of dataset is changed '
- f'from {len(dataset)} to {len(self)} after'
- ' using var len attention.',
- logger='current')
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, index):
- pairs = self.data[index]
- input_ids, cu_seqlens, position_ids, labels = [], [0], [], []
-
- for pair in pairs:
- input_ids.extend(pair['chosen_ids'])
- input_ids.extend(pair['rejected_ids'])
-
- position_ids.extend(list(range(len(pair['chosen_ids']))))
- position_ids.extend(list(range(len(pair['rejected_ids']))))
-
- labels.extend(pair['chosen_labels'])
- labels.extend(pair['rejected_labels'])
-
- cu_seqlens.append(cu_seqlens[-1] + len(pair['chosen_ids']))
- cu_seqlens.append(cu_seqlens[-1] + len(pair['rejected_ids']))
-
- return {
- 'input_ids': input_ids,
- 'labels': labels,
- 'position_ids': position_ids,
- 'cumulative_len': cu_seqlens
- }
-
-
-def unpack_seq(seq, cu_seqlens):
- """Unpack a packed sequence to a list of sequences with different
- lengths."""
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- subseqs = seq.split(seqlens)
- return subseqs
-
-
-def broad_cast_dataset(dataset):
- xtuner_dataset_timeout = timedelta(
- minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=60)))
- print_log(
- f'xtuner_dataset_timeout = {xtuner_dataset_timeout}', logger='current')
- using_dist = dist.is_available() and dist.is_initialized()
- if using_dist:
- # monitored barrier requires gloo process group to perform host-side sync. # noqa
- group_gloo = dist.new_group(
- backend='gloo', timeout=xtuner_dataset_timeout)
- if not using_dist or dist.get_rank() == 0:
- objects = [dataset]
- else:
- objects = [None]
- if using_dist:
- dist.monitored_barrier(
- group=group_gloo, timeout=xtuner_dataset_timeout)
- dist.broadcast_object_list(objects, src=0)
- return objects[0]
-
-
-def map_dataset(dataset, dataset_map_fn, map_num_proc):
- if isinstance(dataset_map_fn, str):
- map_fn_obj = MAP_FUNC.get(dataset_map_fn) or get_object_from_string(
- dataset_map_fn)
- if map_fn_obj is not None:
- dataset_map_fn = map_fn_obj
- else:
- raise TypeError('dataset_map_fn must be a function or a '
- "registered function's string in MAP_FUNC, "
- f"but got a string of '{dataset_map_fn}'")
-
- dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc)
- return dataset
-
-
-def build_preference_dataset(
- dataset: str,
- tokenizer: AutoTokenizer,
- max_length: int,
- dataset_map_fn: Callable = None,
- is_dpo: bool = True,
- is_reward: bool = False,
- reward_token_id: int = -1,
- num_proc: int = 32,
- use_varlen_attn: bool = False,
- max_packed_length: int = 16384,
- shuffle_before_pack: bool = True,
-) -> Dataset:
- using_dist = dist.is_available() and dist.is_initialized()
- tokenized_ds = None
- if not using_dist or dist.get_rank() == 0:
- if isinstance(tokenizer, dict) or isinstance(
- tokenizer, Config) or isinstance(tokenizer, ConfigDict):
- tokenizer = BUILDER.build(tokenizer)
-
- dataset = build_origin_dataset(dataset, split='train')
- if dataset_map_fn is not None:
- dataset = map_dataset(
- dataset, dataset_map_fn, map_num_proc=num_proc)
-
- tokenized_ds = PreferenceDataset(
- dataset=dataset,
- tokenizer=tokenizer,
- max_length=max_length,
- is_dpo=is_dpo,
- is_reward=is_reward,
- reward_token_id=reward_token_id,
- num_proc=num_proc,
- )
- if use_varlen_attn:
- tokenized_ds = PackedDatasetWrapper(
- dataset=tokenized_ds,
- max_packed_length=max_packed_length,
- shuffle_before_pack=shuffle_before_pack,
- )
- tokenized_ds = broad_cast_dataset(tokenized_ds)
- return tokenized_ds
-
-
-def intel_orca_dpo_map_fn(example):
- prompt = [{
- 'role': 'system',
- 'content': example['system']
- }, {
- 'role': 'user',
- 'content': example['question']
- }]
- chosen = [{'role': 'assistant', 'content': example['chosen']}]
- rejected = [{'role': 'assistant', 'content': example['rejected']}]
- return {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
-
-
-def orpo_dpo_mix_40k_map_fn(example):
- assert len(example['chosen']) == len(example['rejected'])
- prompt = example['chosen'][:-1]
- chosen = example['chosen'][-1:]
- rejected = example['rejected'][-1:]
- return {'prompt': prompt, 'chosen': chosen, 'rejected': rejected}
diff --git a/code/xtuner/dataset/refcoco_json.py b/code/xtuner/dataset/refcoco_json.py
deleted file mode 100644
index e32f08ae459a21697e5a1736ad8a19bafaf767e5..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/refcoco_json.py
+++ /dev/null
@@ -1,496 +0,0 @@
-import copy
-import itertools
-import json
-import os
-import pickle
-import time
-from collections import defaultdict
-
-import matplotlib.pyplot as plt
-import numpy as np
-import skimage.io as io
-import torch
-from datasets import Dataset as HFDataset
-from datasets import DatasetDict
-from matplotlib.patches import Polygon, Rectangle
-from mmengine.config import Config, ConfigDict
-from PIL import Image
-
-from xtuner.registry import BUILDER
-from ..registry import BUILDER
-from .huggingface import process_hf_dataset
-from .llava import LLaVADataset
-from .utils import expand2square
-
-
-class RefCOCOJsonDataset(LLaVADataset):
- instruction_pool = [
- '[refer] {}',
- '[refer] give me the location of {}',
- '[refer] where is {} ?',
- '[refer] from this image, tell me the location of {}',
- '[refer] the location of {} is',
- '[refer] could you tell me the location for {} ?',
- '[refer] where can I locate the {} ?',
- ]
-
- def __init__(
- self,
- data_path,
- image_folder,
- tokenizer,
- image_processor,
- max_dataset_length=None,
- dataset_map_fn=None,
- template_map_fn=None,
- max_length=2048,
- pad_image_to_square=False,
- ):
- json_data = json.load(open(data_path))
-
- ######################################################
- # Only this part is different from LLaVADataset.__init__
- json_data = self.reformat_data(json_data)
- ######################################################
-
- for idx in range(len(json_data)):
- if isinstance(json_data[idx]['id'], int):
- json_data[idx]['id'] = str(json_data[idx]['id'])
- json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
- self.text_data = process_hf_dataset(
- dataset=json_data,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=dataset_map_fn,
- template_map_fn=template_map_fn,
- split='train',
- max_dataset_length=max_dataset_length,
- remove_unused_columns=False,
- pack_to_max_length=False,
- with_image_token=True)
-
- self.image_folder = image_folder
- if isinstance(image_processor, dict) or isinstance(
- image_processor, Config) or isinstance(image_processor,
- ConfigDict):
- self.image_processor = BUILDER.build(image_processor)
- else:
- self.image_processor = image_processor
- self.pad_image_to_square = pad_image_to_square
-
- def reformat_data(self, json_data):
- new_json_data = []
- for sample in json_data:
- for instruction_template in self.instruction_pool:
- sample['conversations'] = self.gen_refcoco_conversations(
- sample, instruction_template)
- new_json_data.append(copy.deepcopy(sample))
- return new_json_data
-
- @classmethod
- def gen_refcoco_conversations(cls, data, instruction_template='{}'):
- """build conversition data from refcoco json data as below.
-
- "id": "xxx",
- "image": "xxx.jpg",
- "conversations": [
- {
- "from": "human",
- "value": "xxxx"
- },
- {
- "from": "gpt",
- "value": "xxx"
- }
- """
-
- conversation = [
- {
- 'from': 'human',
- 'value': ''
- },
- {
- 'from': 'gpt',
- 'value': ''
- },
- ]
-
- instruction = instruction_template.format(data['sents'])
- bbox = cls.normalize_bbox(data['bbox'], data['height'], data['width'])
- answer = '{{<{}><{}><{}><{}>}}'.format(bbox[0], bbox[1], bbox[2],
- bbox[3])
- conversation[0]['value'] = instruction + '\n'
- conversation[1]['value'] = answer
- return conversation
-
- @classmethod
- def get_data_json(
- cls,
- ann_path,
- image_path,
- dataset='refcoco',
- splitBy='unc',
- ):
- refer = REFER(ann_path, image_path, dataset, splitBy)
- ref_ids = refer.getRefIds(split='train')
-
- data = {}
- duplicate_data = defaultdict(list)
-
- for ref_id in ref_ids:
- ref = refer.loadRefs(ref_id)[0]
-
- image_id = '{:0>12}'.format(ref['image_id'])
- sents = [sent['raw'] for sent in ref['sentences']]
- bbox = refer.getRefBox(ref['ref_id'])
-
- image = Image.open(image_path + '/' + image_id + '.jpg')
-
- for sent in sents:
- sent_id = '_'.join(sent.split(' '))
- data_id = f'{dataset}-{splitBy}-{image_id}-{sent_id}'
- data_item = {
- 'id': data_id,
- 'image': 'coco/train2017/' + image_id + '.jpg',
- 'sents': sent,
- 'bbox': bbox,
- 'height': image.height,
- 'width': image.width
- }
- if data_id in data:
- duplicate_data[data_id].append(data_item)
- else:
- data[data_id] = data_item
-
- return list(data.values()), list(duplicate_data.values())
-
- @classmethod
- def normalize_bbox(cls, bbox, height, width):
- x, y, w, h = bbox
-
- bbox = [x / width, y / height, (x + w) / width, (y + h) / height]
- bbox = [int(x * 100) for x in bbox]
- return bbox
-
-
-class RefCOCOJsonEvalDataset(RefCOCOJsonDataset):
- instruction_pool = ['[refer] give me the location of {}']
-
- def reformat_data(self, json_data):
- for sample in json_data:
- # reformat img_id
- img_id = sample['img_id'].split('_')[-2]
- sample['image'] = 'coco/train2017/' + img_id + '.jpg'
- sample['id'] = f"{img_id}-{sample['sents']}"
- return super().reformat_data(json_data)
-
-
-class InvRefCOCOJsonDataset(RefCOCOJsonDataset):
- instruction_pool = [
- '[identify] {}',
- '[identify] what object is in this location {}',
- '[identify] identify the object present at this location {}',
- '[identify] what is it in {}',
- '[identify] describe this object in {}',
- '[identify] this {} is',
- '[identify] the object in {} is',
- ]
-
- @classmethod
- def gen_refcoco_conversations(cls, data, instruction_template='{}'):
- """build conversition data from refcoco json data as below.
-
- "id": "xxx",
- "image": "xxx.jpg",
- "conversations": [
- {
- "from": "human",
- "value": "xxxx"
- },
- {
- "from": "gpt",
- "value": "xxx"
- }
- """
-
- conversation = [
- {
- 'from': 'human',
- 'value': ''
- },
- {
- 'from': 'gpt',
- 'value': ''
- },
- ]
- bbox = cls.normalize_bbox(data['bbox'], data['height'], data['width'])
- bbox_str = '{{<{}><{}><{}><{}>}}'.format(bbox[0], bbox[1], bbox[2],
- bbox[3])
- instruction = instruction_template.format(bbox_str)
- answer = data['sents']
-
- conversation[0]['value'] = instruction + '\n'
- conversation[1]['value'] = answer
- return conversation
-
-
-# flake8: noqa
-# Refer
-
-
-class REFER:
-
- def __init__(self, data_root, vis_root, dataset='refcoco', splitBy='unc'):
- # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
- # also provide dataset name and splitBy information
- # e.g., dataset = 'refcoco', splitBy = 'unc'
- # inv dataset is stored in the same path as normal dataset
- dataset = dataset.split('inv')[-1]
- print('loading dataset %s into memory...' % dataset)
- self.ann_dir = os.path.join(data_root, dataset)
- if dataset in ['refcoco', 'refcoco+', 'refcocog']:
- self.vis_root = vis_root
- elif dataset == 'refclef':
- raise 'No RefClef image data'
- else:
- raise 'No refer dataset is called [%s]' % dataset
-
- # load refs from data/dataset/refs(dataset).json
- tic = time.time()
- ref_file = os.path.join(self.ann_dir, 'refs(' + splitBy + ').p')
- self.data = {}
- self.data['dataset'] = dataset
- self.data['refs'] = pickle.load(open(ref_file, 'rb'))
-
- # load annotations from data/dataset/instances.json
- instances_file = os.path.join(self.ann_dir, 'instances.json')
- instances = json.load(open(instances_file))
- self.data['images'] = instances['images']
- self.data['annotations'] = instances['annotations']
- self.data['categories'] = instances['categories']
-
- # create index
- self.createIndex()
- print('DONE (t=%.2fs)' % (time.time() - tic))
-
- def createIndex(self):
- # create sets of mapping
- # 1) Refs: {ref_id: ref}
- # 2) Anns: {ann_id: ann}
- # 3) Imgs: {image_id: image}
- # 4) Cats: {category_id: category_name}
- # 5) Sents: {sent_id: sent}
- # 6) imgToRefs: {image_id: refs}
- # 7) imgToAnns: {image_id: anns}
- # 8) refToAnn: {ref_id: ann}
- # 9) annToRef: {ann_id: ref}
- # 10) catToRefs: {category_id: refs}
- # 11) sentToRef: {sent_id: ref}
- # 12) sentToTokens: {sent_id: tokens}
- print('creating index...')
- # fetch info from instances
- Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
- for ann in self.data['annotations']:
- Anns[ann['id']] = ann
- imgToAnns[ann['image_id']] = imgToAnns.get(ann['image_id'],
- []) + [ann]
- for img in self.data['images']:
- Imgs[img['id']] = img
- for cat in self.data['categories']:
- Cats[cat['id']] = cat['name']
-
- # fetch info from refs
- Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
- Sents, sentToRef, sentToTokens = {}, {}, {}
- for ref in self.data['refs']:
- # ids
- ref_id = ref['ref_id']
- ann_id = ref['ann_id']
- category_id = ref['category_id']
- image_id = ref['image_id']
-
- # add mapping related to ref
- Refs[ref_id] = ref
- imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
- catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
- refToAnn[ref_id] = Anns[ann_id]
- annToRef[ann_id] = ref
-
- # add mapping of sent
- for sent in ref['sentences']:
- Sents[sent['sent_id']] = sent
- sentToRef[sent['sent_id']] = ref
- sentToTokens[sent['sent_id']] = sent['tokens']
-
- # create class members
- self.Refs = Refs
- self.Anns = Anns
- self.Imgs = Imgs
- self.Cats = Cats
- self.Sents = Sents
- self.imgToRefs = imgToRefs
- self.imgToAnns = imgToAnns
- self.refToAnn = refToAnn
- self.annToRef = annToRef
- self.catToRefs = catToRefs
- self.sentToRef = sentToRef
- self.sentToTokens = sentToTokens
- print('index created.')
-
- def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=''):
- image_ids = image_ids if type(image_ids) == list else [image_ids]
- cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
-
- if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
- refs = self.data['refs']
- else:
- if not len(image_ids) == 0:
- refs = [self.imgToRefs[image_id] for image_id in image_ids]
- else:
- refs = self.data['refs']
- if not len(cat_ids) == 0:
- refs = [ref for ref in refs if ref['category_id'] in cat_ids]
- if not len(ref_ids) == 0:
- refs = [ref for ref in refs if ref['ref_id'] in ref_ids]
- if not len(split) == 0:
- if split in ['testA', 'testB', 'testC']:
- refs = [ref for ref in refs if split[-1] in ref['split']
- ] # we also consider testAB, testBC, ...
- elif split in ['testAB', 'testBC', 'testAC']:
- # rarely used I guess...
- refs = [ref for ref in refs if ref['split'] == split]
- elif split == 'test':
- refs = [ref for ref in refs if 'test' in ref['split']]
- elif split == 'train' or split == 'val':
- refs = [ref for ref in refs if ref['split'] == split]
- else:
- raise 'No such split [%s]' % split
- ref_ids = [ref['ref_id'] for ref in refs]
- return ref_ids
-
- def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
- image_ids = image_ids if type(image_ids) == list else [image_ids]
- cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
-
- if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
- ann_ids = [ann['id'] for ann in self.data['annotations']]
- else:
- if not len(image_ids) == 0:
- lists = [
- self.imgToAnns[image_id] for image_id in image_ids
- if image_id in self.imgToAnns
- ] # list of [anns]
- anns = list(itertools.chain.from_iterable(lists))
- else:
- anns = self.data['annotations']
- if not len(cat_ids) == 0:
- anns = [ann for ann in anns if ann['category_id'] in cat_ids]
- ann_ids = [ann['id'] for ann in anns]
- if not len(ref_ids) == 0:
- ids = set(ann_ids).intersection(
- {self.Refs[ref_id]['ann_id']
- for ref_id in ref_ids})
- return ann_ids
-
- def getImgIds(self, ref_ids=[]):
- ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
-
- if not len(ref_ids) == 0:
- image_ids = list(
- {self.Refs[ref_id]['image_id']
- for ref_id in ref_ids})
- else:
- image_ids = self.Imgs.keys()
- return image_ids
-
- def getCatIds(self):
- return self.Cats.keys()
-
- def loadRefs(self, ref_ids=[]):
- if type(ref_ids) == list:
- return [self.Refs[ref_id] for ref_id in ref_ids]
- elif type(ref_ids) == int:
- return [self.Refs[ref_ids]]
-
- def loadAnns(self, ann_ids=[]):
- if type(ann_ids) == list:
- return [self.Anns[ann_id] for ann_id in ann_ids]
- elif type(ann_ids) == int:
- return [self.Anns[ann_ids]]
-
- def loadImgs(self, image_ids=[]):
- if type(image_ids) == list:
- return [self.Imgs[image_id] for image_id in image_ids]
- elif type(image_ids) == int:
- return [self.Imgs[image_ids]]
-
- def loadCats(self, cat_ids=[]):
- if type(cat_ids) == list:
- return [self.Cats[cat_id] for cat_id in cat_ids]
- elif type(cat_ids) == int:
- return [self.Cats[cat_ids]]
-
- def getRefBox(self, ref_id):
- ref = self.Refs[ref_id]
- ann = self.refToAnn[ref_id]
- return ann['bbox'] # [x, y, w, h]
-
- def showRef(self, ref, seg_box='box'):
- from matplotlib.collectns import PatchCollection
-
- ax = plt.gca()
- # show image
- image = self.Imgs[ref['image_id']]
- I = io.imread(os.path.join(self.vis_root, image['file_name']))
- ax.imshow(I)
- # show refer expression
- for sid, sent in enumerate(ref['sentences']):
- print('{}. {}'.format(sid + 1, sent['sent']))
- # show segmentations
- if seg_box == 'seg':
- ann_id = ref['ann_id']
- ann = self.Anns[ann_id]
- polygons = []
- color = []
- c = 'none'
- if type(ann['segmentation'][0]) == list:
- # polygon used for refcoco*
- for seg in ann['segmentation']:
- poly = np.array(seg).reshape((len(seg) / 2, 2))
- polygons.append(Polygon(poly, True, alpha=0.4))
- color.append(c)
- p = PatchCollection(
- polygons,
- facecolors=color,
- edgecolors=(1, 1, 0, 0),
- linewidths=3,
- alpha=1,
- )
- ax.add_collection(p) # thick yellow polygon
- p = PatchCollection(
- polygons,
- facecolors=color,
- edgecolors=(1, 0, 0, 0),
- linewidths=1,
- alpha=1,
- )
- ax.add_collection(p) # thin red polygon
- else:
- # mask used for refclef
- raise NotImplementedError('RefClef is not downloaded')
- # show bounding-box
- elif seg_box == 'box':
- ann_id = ref['ann_id']
- ann = self.Anns[ann_id]
- bbox = self.getRefBox(ref['ref_id'])
- box_plot = Rectangle(
- (bbox[0], bbox[1]),
- bbox[2],
- bbox[3],
- fill=False,
- edgecolor='green',
- linewidth=3,
- )
- ax.add_patch(box_plot)
diff --git a/code/xtuner/dataset/samplers/__init__.py b/code/xtuner/dataset/samplers/__init__.py
deleted file mode 100644
index 8afc9bc1e2bbaae2e00a530302c24106400f2ace..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/samplers/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .intern_repo import InternlmRepoSampler, InternRepoSampler
-from .length_grouped import LengthGroupedSampler
-
-__all__ = ['LengthGroupedSampler', 'InternRepoSampler', 'InternlmRepoSampler']
diff --git a/code/xtuner/dataset/samplers/intern_repo.py b/code/xtuner/dataset/samplers/intern_repo.py
deleted file mode 100644
index 933719a58e5c8efa46d14bc5080bd7ed1e9b0ce4..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/samplers/intern_repo.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import logging
-import warnings
-from typing import Iterator, Optional, Sized
-
-import numpy as np
-from mmengine import print_log
-from torch.utils.data import Sampler
-
-from xtuner.parallel.sequence import (get_data_parallel_rank,
- get_data_parallel_world_size)
-
-
-class InternRepoSampler(Sampler):
-
- def __init__(self,
- dataset: Sized,
- shuffle: bool = True,
- seed: Optional[int] = None) -> None:
- if seed is not None and seed != 1024:
- warnings.warn('For alignment accuracy, seed in InternRepoSampler'
- 'must be set to 1024.')
- world_size = get_data_parallel_world_size()
- rank = get_data_parallel_rank()
- self.rank = rank
- self.world_size = world_size
-
- self.dataset = dataset
- self.shuffle = shuffle
- self.seed = 1024
- self.epoch = 0
-
- self.num_samples = len(self.dataset) // world_size
- self.total_size = self.num_samples * world_size
-
- def __iter__(self) -> Iterator[int]:
- """Iterate the indices."""
- # deterministically shuffle based on epoch and seed
- if self.shuffle:
- rng = np.random.RandomState(self.seed + self.epoch)
- indices = np.arange(len(self.dataset))
- rng.shuffle(indices)
- indices = indices.tolist()
- else:
- indices = np.arange(len(self.dataset)).tolist()
-
- self.indices = indices[:self.total_size]
-
- # subsample
- indices = indices[self.rank:self.total_size:self.world_size]
- self.subsample_indices = indices
-
- return iter(indices)
-
- def __len__(self) -> int:
- """The number of samples in this rank."""
- return self.num_samples
-
- def set_epoch(self, epoch: int) -> None:
- """Sets the epoch for this sampler.
-
- When :attr:`shuffle=True`, this ensures all replicas use a different
- random ordering for each epoch. Otherwise, the next iteration of this
- sampler will yield the same ordering.
-
- Args:
- epoch (int): Epoch number.
- """
- self.epoch = epoch
-
-
-class InternlmRepoSampler(InternRepoSampler):
-
- def __init__(self,
- dataset: Sized,
- shuffle: bool = True,
- seed: Optional[int] = None) -> None:
- super().__init__(dataset, shuffle, seed)
- print_log(('InternlmRepoSampler will be deprecated in the future.'
- 'Please use InternRepoSampler instead.'),
- logger='current',
- level=logging.WARNING)
diff --git a/code/xtuner/dataset/samplers/length_grouped.py b/code/xtuner/dataset/samplers/length_grouped.py
deleted file mode 100644
index 184827837cf062972d6b024940ba6d252577efd4..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/samplers/length_grouped.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-from typing import Iterator, Optional, Sized
-
-import torch
-from mmengine.dist import get_dist_info, sync_random_seed
-from mmengine.logging import print_log
-from torch.utils.data import ConcatDataset as TorchConcatDataset
-from torch.utils.data import Sampler
-
-
-def get_length_grouped_indices(lengths, group_batch_size, generator=None):
-
- def process(lengths, group_batch_size, generator=None):
- indices = torch.randperm(len(lengths), generator=generator)
- megabatches = [
- indices[i:i + group_batch_size].tolist()
- for i in range(0, len(lengths), group_batch_size)
- ]
- megabatches = [
- sorted(megabatch, key=lambda i: lengths[i], reverse=True)
- for megabatch in megabatches
- ]
- return megabatches
-
- assert all(leng != 0 for leng in lengths), 'Should not have zero length.'
- if all(leng > 0 for leng in lengths) or all(leng < 0 for leng in lengths):
- # all samples are in the same modality
- megabatches = process(lengths, group_batch_size, generator=generator)
- else:
- mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths)
- if l > 0])
- lang_indices, lang_lengths = zip(*[(i, -l)
- for i, l in enumerate(lengths)
- if l < 0])
- mm_megabatches = []
- for mm_megabatch in process(
- mm_lengths, group_batch_size, generator=generator):
- mm_megabatches.append([mm_indices[i] for i in mm_megabatch])
- lang_megabatches = []
- for lang_megabatch in process(
- lang_lengths, group_batch_size, generator=generator):
- lang_megabatches.append([lang_indices[i] for i in lang_megabatch])
-
- last_mm = mm_megabatches[-1]
- last_lang = lang_megabatches[-1]
- last_batch = last_mm + last_lang
- megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
-
- megabatch_indices = torch.randperm(
- len(megabatches), generator=generator)
- megabatches = [megabatches[i] for i in megabatch_indices]
-
- if len(last_batch) > 0:
- megabatches.append(
- sorted(
- last_batch, key=lambda i: abs(lengths[i]), reverse=True))
-
- # The rest is to get the biggest batch first.
- # Since each megabatch is sorted by descending length,
- # the longest element is the first
- megabatch_maximums = [
- abs(lengths[megabatch[0]]) for megabatch in megabatches
- ]
- max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
- # Switch to put the longest element in first position
- megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][
- 0], megabatches[0][0]
-
- return [i for megabatch in megabatches for i in megabatch]
-
-
-class LengthGroupedSampler(Sampler):
-
- def __init__(self,
- dataset: Sized,
- per_device_batch_size: int,
- length_property='length',
- mega_batch_mult: Optional[int] = None,
- seed: Optional[int] = None,
- round_up: bool = True) -> None:
- print_log('LengthGroupedSampler is used.', logger='current')
- rank, world_size = get_dist_info()
- self.rank = rank
- self.world_size = world_size
-
- self.dataset = dataset
- if seed is None:
- seed = sync_random_seed()
- self.seed = seed
- self.epoch = 0
- self.round_up = round_up
-
- if self.round_up:
- num_iters = math.ceil(
- len(self.dataset) / world_size / per_device_batch_size)
- self.num_samples = num_iters * per_device_batch_size
- self.total_size = self.num_samples * self.world_size
- else:
- self.num_samples = math.ceil(
- (len(self.dataset) - rank) / world_size)
- self.total_size = len(self.dataset)
-
- total_batch_size = per_device_batch_size * self.world_size
- if mega_batch_mult is None:
- # Default for mega_batch_mult: 50 or the number to get 4
- # megabatches, whichever is smaller.
- mega_batch_mult = min(
- len(self.dataset) // (total_batch_size * 4), 50)
- # Just in case, for tiny datasets
- if mega_batch_mult == 0:
- mega_batch_mult = 1
- self.group_batch_size = mega_batch_mult * total_batch_size
-
- if isinstance(self.dataset, TorchConcatDataset):
- length = []
- for sub_dataset in self.dataset.datasets:
- length.extend(getattr(sub_dataset, length_property))
- self.length = length
- else:
- self.length = getattr(self.dataset, length_property)
- assert isinstance(self.length, (list, tuple))
-
- self.total_batch_size = total_batch_size
- print_log(
- f'LengthGroupedSampler construction is complete, '
- f'and the selected attribute is {length_property}',
- logger='current')
-
- def __iter__(self) -> Iterator[int]:
- """Iterate the indices."""
- generator = torch.Generator()
- generator.manual_seed(self.seed + self.epoch)
- indices = get_length_grouped_indices(
- lengths=self.length,
- group_batch_size=self.group_batch_size,
- generator=generator)
- assert len(set(indices)) == len(indices)
- # add extra samples to make it evenly divisible
- if self.round_up:
- indices = (
- indices *
- int(self.total_size / len(indices) + 1))[:self.total_size]
- # subsample
- assert len(indices) == self.total_size
- indices = indices[self.rank:self.total_size:self.world_size]
- assert len(indices) == self.num_samples
- return iter(indices)
-
- def __len__(self) -> int:
- """The number of samples in this rank."""
- return self.num_samples
-
- def set_epoch(self, epoch: int) -> None:
- """Sets the epoch for this sampler.
-
- When :attr:`shuffle=True`, this ensures all replicas use a different
- random ordering for each epoch. Otherwise, the next iteration of this
- sampler will yield the same ordering.
-
- Args:
- epoch (int): Epoch number.
- """
- self.epoch = epoch
diff --git a/code/xtuner/dataset/utils.py b/code/xtuner/dataset/utils.py
deleted file mode 100644
index 2cfd225d3bcde5e90597f22c4a7d19aced36e83c..0000000000000000000000000000000000000000
--- a/code/xtuner/dataset/utils.py
+++ /dev/null
@@ -1,580 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import base64
-import copy
-import io
-from io import BytesIO
-from itertools import chain
-
-import numpy as np
-import requests
-from PIL import Image
-import h5py
-
-from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
-
-
-def get_bos_eos_token_ids(tokenizer):
- if tokenizer.__class__.__name__ in [
- 'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'
- ]:
- bos_token_id = []
- eos_token_id = tokenizer.eos_token_id
- assert eos_token_id is not None, \
- 'Please set eos_token for Qwen tokenizer!'
- elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
- bos_token_id = [64790, 64792]
- eos_token_id = tokenizer.eos_token_id
- else:
- bos_token_id = tokenizer.bos_token_id
- eos_token_id = tokenizer.eos_token_id
- if isinstance(bos_token_id, int):
- bos_token_id = [bos_token_id]
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
- return bos_token_id, eos_token_id
-
-
-def encode_fn(example,
- tokenizer,
- max_length,
- input_ids_with_output=True,
- with_image_token=False):
- """We only support the following three scenarios:
-
- 1. Incremental pretraining dataset.
- example['conversation'] = [
- {
- 'input': '',
- 'output': '### Human: Can you write xxx'
- }
- ]
-
- 2. Single-turn conversation dataset.
- example['conversation'] = [
- {
- 'input': 'Give three tips for staying healthy.',
- 'output': '1.Eat a balanced diet xxx'
- }
- ]
-
- 3. Multi-turn conversation dataset.
- example['conversation'] = [
- {
- 'input': 'Give three tips for staying healthy.',
- 'output': '1.Eat a balanced diet xxx'
- },
- {
- 'input': 'Please expand on the second point.',
- 'output': 'Here is an expanded explanation of the xxx'
- }
- ]
- """
- bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
- is_multi_turn_conversation = len(example['conversation']) > 1
- if is_multi_turn_conversation:
- assert input_ids_with_output
-
- input_ids, labels = [], []
- next_needs_bos_token = True
- for single_turn_conversation in example['conversation']:
- input = single_turn_conversation['input']
- if DEFAULT_IMAGE_TOKEN in input and with_image_token:
- chunk_encode = [
- tokenizer.encode(chunk, add_special_tokens=False)
- for chunk in input.split(DEFAULT_IMAGE_TOKEN)
- ]
- assert len(chunk_encode) == 2
- input_encode = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_encode.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_encode.append(IMAGE_TOKEN_INDEX)
- else:
- input_encode = tokenizer.encode(input, add_special_tokens=False)
- if next_needs_bos_token:
- input_ids += bos_token_id
- labels += [IGNORE_INDEX] * len(bos_token_id)
- input_ids += input_encode
- labels += [IGNORE_INDEX] * len(input_encode)
- if input_ids_with_output:
- # Add output
- output_with_loss = single_turn_conversation.get(
- 'output_with_loss', True)
- output = single_turn_conversation['output']
- output_encode = tokenizer.encode(output, add_special_tokens=False)
- input_ids += output_encode
- if output_with_loss:
- labels += copy.deepcopy(output_encode)
- else:
- labels += [IGNORE_INDEX] * len(output_encode)
- # Add EOS_TOKEN (with loss)
- if single_turn_conversation.get('need_eos_token', True):
- next_needs_bos_token = True
- input_ids += eos_token_id
- if output_with_loss:
- labels += copy.deepcopy(eos_token_id)
- else:
- labels += [IGNORE_INDEX] * len(eos_token_id)
- else:
- next_needs_bos_token = False
- # Add SEP (without loss)
- sep = single_turn_conversation.get('sep', '')
- if sep != '':
- sep_encode = tokenizer.encode(sep, add_special_tokens=False)
- input_ids += sep_encode
- labels += [IGNORE_INDEX] * len(sep_encode)
-
- if len(input_ids) > max_length:
- input_ids = input_ids[:max_length]
- labels = labels[:max_length]
- return {'input_ids': input_ids, 'labels': labels}
-
-
-class Packer:
- """Pack multiple pieces of data into one."""
-
- def __init__(self,
- chunk_size=2048,
- use_varlen_attn=False,
- drop_last=False):
- self.chunk_size = chunk_size
- self.residual = {'input_ids': [], 'labels': []}
- self.use_varlen_attn = use_varlen_attn
- self.drop_last = drop_last
- if use_varlen_attn:
- self.residual_cumulative_len = [0]
-
- def get_cumulative_len(self, chunk_num):
- ptr_l = 0
- cumulative_len = []
- for chunk_idx in range(chunk_num):
- length_train = (chunk_idx + 1) * self.chunk_size
- ptr_r = np.searchsorted(
- self.residual_cumulative_len, length_train, side='left')
- if self.residual_cumulative_len[ptr_r] == length_train:
- cumulative_len_cur = \
- self.residual_cumulative_len[ptr_l:ptr_r + 1]
- ptr_l = ptr_r + 1
- else:
- cumulative_len_cur = self.residual_cumulative_len[
- ptr_l:ptr_r] + [length_train]
- ptr_l = ptr_r
- cumulative_len_cur = [
- num - chunk_idx * self.chunk_size for num in cumulative_len_cur
- ]
- if cumulative_len_cur[0] != 0:
- cumulative_len_cur = [0] + cumulative_len_cur
-
- cumulative_len.append(cumulative_len_cur)
-
- self.residual_cumulative_len = [
- num - length_train for num in self.residual_cumulative_len[ptr_l:]
- ]
- if len(self.residual_cumulative_len) == 0:
- self.residual_cumulative_len = [0]
- elif self.residual_cumulative_len[0] != 0:
- self.residual_cumulative_len = [0] + self.residual_cumulative_len
-
- return cumulative_len
-
- def get_position_ids(self, cumulative_len):
- position_ids = []
- for cumulative_len_cur in cumulative_len:
- index_cur = []
- for i in range(len(cumulative_len_cur) - 1):
- index_cur.extend(
- list(
- range(cumulative_len_cur[i + 1] - # noqa: W504
- cumulative_len_cur[i])))
- position_ids.append(index_cur)
- return position_ids
-
- def __call__(self, batch):
- concatenated_samples = {
- k: v + list(chain(*batch[k]))
- for k, v in self.residual.items()
- }
-
- if self.use_varlen_attn:
- for input_id in batch['input_ids']:
- self.residual_cumulative_len.append(
- self.residual_cumulative_len[-1] + len(input_id))
-
- total_length = len(concatenated_samples[list(
- concatenated_samples.keys())[0]])
-
- if total_length >= self.chunk_size:
- chunk_num = total_length // self.chunk_size
- result = {
- k: [
- v[i:i + self.chunk_size] for i in range(
- 0,
- chunk_num * # noqa: W504
- self.chunk_size,
- self.chunk_size)
- ]
- for k, v in concatenated_samples.items()
- }
- self.residual = {
- k: v[(chunk_num * self.chunk_size):]
- for k, v in concatenated_samples.items()
- }
-
- if self.use_varlen_attn:
- cumulative_len = self.get_cumulative_len(chunk_num)
- result['cumulative_len'] = cumulative_len
- result['position_ids'] = self.get_position_ids(cumulative_len)
- else:
- if self.drop_last:
- result = {k: [] for k, v in concatenated_samples.items()}
- else:
- result = {k: [v] for k, v in concatenated_samples.items()}
-
- self.residual = {k: [] for k in concatenated_samples.keys()}
-
- if self.use_varlen_attn:
- result['cumulative_len'] = [] if self.drop_last else [
- self.residual_cumulative_len
- ]
- result['position_ids'] = [] if self.drop_last \
- else self.get_position_ids([self.residual_cumulative_len])
- self.residual_cumulative_len = [0]
-
- return result
-
-
-def expand2square(pil_img, background_color):
- width, height = pil_img.size
- if width == height:
- return pil_img
- elif width > height:
- result = Image.new(pil_img.mode, (width, width), background_color)
- result.paste(pil_img, (0, (width - height) // 2))
- return result
- else:
- result = Image.new(pil_img.mode, (height, height), background_color)
- result.paste(pil_img, ((height - width) // 2, 0))
- return result
-
-
-def load_image(image_file):
- if image_file.startswith('http://') or image_file.startswith('https://'):
- response = requests.get(image_file)
- image = Image.open(BytesIO(response.content)).convert('RGB')
- else:
- image = Image.open(image_file).convert('RGB')
- return image
-
-
-def decode_base64_to_image(base64_string):
- image_data = base64.b64decode(base64_string)
- image = Image.open(io.BytesIO(image_data))
- return image
-# Copyright (c) OpenMMLab. All rights reserved.
-import base64
-import copy
-import io
-from io import BytesIO
-from itertools import chain
-
-import numpy as np
-import requests
-from PIL import Image
-
-import pandas as pd
-
-from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
-
-
-def get_bos_eos_token_ids(tokenizer):
- if tokenizer.__class__.__name__ in [
- 'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'
- ]:
- bos_token_id = []
- eos_token_id = tokenizer.eos_token_id
- assert eos_token_id is not None, \
- 'Please set eos_token for Qwen tokenizer!'
- elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
- bos_token_id = [64790, 64792]
- eos_token_id = tokenizer.eos_token_id
- else:
- bos_token_id = tokenizer.bos_token_id
- eos_token_id = tokenizer.eos_token_id
- if isinstance(bos_token_id, int):
- bos_token_id = [bos_token_id]
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
- return bos_token_id, eos_token_id
-
-
-def encode_fn(example,
- tokenizer,
- max_length,
- input_ids_with_output=True,
- with_image_token=False,
- per_image_length=0):
- """We only support the following three scenarios:
-
- 1. Incremental pretraining dataset.
- example['conversation'] = [
- {
- 'input': '',
- 'output': '### Human: Can you write xxx'
- }
- ]
-
- 2. Single-turn conversation dataset.
- example['conversation'] = [
- {
- 'input': 'Give three tips for staying healthy.',
- 'output': '1.Eat a balanced diet xxx'
- }
- ]
-
- 3. Multi-turn conversation dataset.
- example['conversation'] = [
- {
- 'input': 'Give three tips for staying healthy.',
- 'output': '1.Eat a balanced diet xxx'
- },
- {
- 'input': 'Please expand on the second point.',
- 'output': 'Here is an expanded explanation of the xxx'
- }
- ]
- """
- bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
- is_multi_turn_conversation = len(example['conversation']) > 1
- if is_multi_turn_conversation:
- assert input_ids_with_output
-
- input_ids, labels = [], []
- n_images = 0
- next_needs_bos_token = True
- for single_turn_conversation in example['conversation']:
- input = single_turn_conversation['input']
- if DEFAULT_IMAGE_TOKEN in input and with_image_token:
- chunk_encode = [
- tokenizer.encode(chunk, add_special_tokens=False)
- for chunk in input.split(DEFAULT_IMAGE_TOKEN)
- ]
- # assert len(chunk_encode) == 2
- n_images += len(chunk_encode) - 1
- input_encode = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_encode.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_encode.append(IMAGE_TOKEN_INDEX)
- else:
- input_encode = tokenizer.encode(input, add_special_tokens=False)
- if next_needs_bos_token:
- input_ids += bos_token_id
- labels += [IGNORE_INDEX] * len(bos_token_id)
- input_ids += input_encode
- labels += [IGNORE_INDEX] * len(input_encode)
- if input_ids_with_output:
- # Add output
- output_with_loss = single_turn_conversation.get(
- 'output_with_loss', True)
- output = single_turn_conversation['output']
- output_encode = tokenizer.encode(output, add_special_tokens=False)
- input_ids += output_encode
- if output_with_loss:
- labels += copy.deepcopy(output_encode)
- else:
- labels += [IGNORE_INDEX] * len(output_encode)
- # Add EOS_TOKEN (with loss)
- if single_turn_conversation.get('need_eos_token', True):
- next_needs_bos_token = True
- input_ids += eos_token_id
- if output_with_loss:
- labels += copy.deepcopy(eos_token_id)
- else:
- labels += [IGNORE_INDEX] * len(eos_token_id)
- else:
- next_needs_bos_token = False
- # Add SEP (without loss)
- sep = single_turn_conversation.get('sep', '')
- if sep != '':
- sep_encode = tokenizer.encode(sep, add_special_tokens=False)
- input_ids += sep_encode
- labels += [IGNORE_INDEX] * len(sep_encode)
-
- # if len(input_ids) > max_length:
- # input_ids = input_ids[:max_length]
- # labels = labels[:max_length]
- input_ids = input_ids[:max_length - n_images * per_image_length]
- labels = labels[:max_length - n_images * per_image_length]
- return {'input_ids': input_ids, 'labels': labels}
-
-
-class Packer:
- """Pack multiple pieces of data into one."""
-
- def __init__(self,
- chunk_size=2048,
- use_varlen_attn=False,
- drop_last=False):
- self.chunk_size = chunk_size
- self.residual = {'input_ids': [], 'labels': []}
- self.use_varlen_attn = use_varlen_attn
- self.drop_last = drop_last
- if use_varlen_attn:
- self.residual_cumulative_len = [0]
-
- def get_cumulative_len(self, chunk_num):
- ptr_l = 0
- cumulative_len = []
- for chunk_idx in range(chunk_num):
- length_train = (chunk_idx + 1) * self.chunk_size
- ptr_r = np.searchsorted(
- self.residual_cumulative_len, length_train, side='left')
- if self.residual_cumulative_len[ptr_r] == length_train:
- cumulative_len_cur = \
- self.residual_cumulative_len[ptr_l:ptr_r + 1]
- ptr_l = ptr_r + 1
- else:
- cumulative_len_cur = self.residual_cumulative_len[
- ptr_l:ptr_r] + [length_train]
- ptr_l = ptr_r
- cumulative_len_cur = [
- num - chunk_idx * self.chunk_size for num in cumulative_len_cur
- ]
- if cumulative_len_cur[0] != 0:
- cumulative_len_cur = [0] + cumulative_len_cur
-
- cumulative_len.append(cumulative_len_cur)
-
- self.residual_cumulative_len = [
- num - length_train for num in self.residual_cumulative_len[ptr_l:]
- ]
- if len(self.residual_cumulative_len) == 0:
- self.residual_cumulative_len = [0]
- elif self.residual_cumulative_len[0] != 0:
- self.residual_cumulative_len = [0] + self.residual_cumulative_len
-
- return cumulative_len
-
- def get_position_ids(self, cumulative_len):
- position_ids = []
- for cumulative_len_cur in cumulative_len:
- index_cur = []
- for i in range(len(cumulative_len_cur) - 1):
- index_cur.extend(
- list(
- range(cumulative_len_cur[i + 1] - # noqa: W504
- cumulative_len_cur[i])))
- position_ids.append(index_cur)
- return position_ids
-
- def __call__(self, batch):
- concatenated_samples = {
- k: v + list(chain(*batch[k]))
- for k, v in self.residual.items()
- }
-
- if self.use_varlen_attn:
- for input_id in batch['input_ids']:
- self.residual_cumulative_len.append(
- self.residual_cumulative_len[-1] + len(input_id))
-
- total_length = len(concatenated_samples[list(
- concatenated_samples.keys())[0]])
-
- if total_length >= self.chunk_size:
- chunk_num = total_length // self.chunk_size
- result = {
- k: [
- v[i:i + self.chunk_size] for i in range(
- 0,
- chunk_num * # noqa: W504
- self.chunk_size,
- self.chunk_size)
- ]
- for k, v in concatenated_samples.items()
- }
- self.residual = {
- k: v[(chunk_num * self.chunk_size):]
- for k, v in concatenated_samples.items()
- }
-
- if self.use_varlen_attn:
- cumulative_len = self.get_cumulative_len(chunk_num)
- result['cumulative_len'] = cumulative_len
- result['position_ids'] = self.get_position_ids(cumulative_len)
- else:
- if self.drop_last:
- result = {k: [] for k, v in concatenated_samples.items()}
- else:
- result = {k: [v] for k, v in concatenated_samples.items()}
-
- self.residual = {k: [] for k in concatenated_samples.keys()}
-
- if self.use_varlen_attn:
- result['cumulative_len'] = [] if self.drop_last else [
- self.residual_cumulative_len
- ]
- result['position_ids'] = [] if self.drop_last \
- else self.get_position_ids([self.residual_cumulative_len])
- self.residual_cumulative_len = [0]
-
- return result
-
-
-def expand2square(pil_img, background_color):
- width, height = pil_img.size
- if width == height:
- return pil_img
- elif width > height:
- result = Image.new(pil_img.mode, (width, width), background_color)
- result.paste(pil_img, (0, (width - height) // 2))
- return result
- else:
- result = Image.new(pil_img.mode, (height, height), background_color)
- result.paste(pil_img, ((height - width) // 2, 0))
- return result
-
-
-def load_image(image_file):
- if image_file.startswith('http://') or image_file.startswith('https://'):
- response = requests.get(image_file)
- image = Image.open(BytesIO(response.content)).convert('RGB')
- elif image_file.endswith('.csv'):
- image = pd.read_csv(image_file)
- image = image.iloc[:, :512]
-
- total_rows = image.shape[0]
- if total_rows >= 10240:
- indices = np.linspace(0, total_rows - 1, 10240, dtype=int)
- sampled_df = image.iloc[indices]
- image = sampled_df.iloc[:10240]
-
-
- image = image.to_numpy().reshape(1, image.shape[0], 512)
- elif image_file.endswith('.h5'):
- with h5py.File(image_file, 'r') as f:
- feats_np = f['features'][:]
- coords_np = f['coords'][:]
-
- total_rows = feats_np.shape[0]
- if total_rows != coords_np.shape[0]:
- raise ValueError(f"Mismatch rows in features ({total_rows}) vs coords ({coords_np.shape[0]}) for {image_file}")
-
- if total_rows >= 10240:
- indices = np.linspace(0, total_rows - 1, 10240, dtype=int)
- feats_np = feats_np[indices]
- coords_np = coords_np[indices]
- return feats_np.reshape(1, feats_np.shape[0], 512), coords_np.reshape(1, coords_np.shape[0], 2)
- else:
- image = Image.open(image_file).convert('RGB')
- return image
-
-def load_wsi_feature(wsi_file):
- wsi_feats = pd.read_csv(wsi_file)
- wsi_feats = wsi_feats.to_numpy()
- return wsi_feats
-
-def decode_base64_to_image(base64_string):
- image_data = base64.b64decode(base64_string)
- image = Image.open(io.BytesIO(image_data))
- return image
diff --git a/code/xtuner/engine/__init__.py b/code/xtuner/engine/__init__.py
deleted file mode 100644
index 9b1d44cc1cf92c43882d24edf6e02d23096e1282..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from ._strategy import DeepSpeedStrategy
-from .hooks import (DatasetInfoHook, EvaluateChatHook, ThroughputHook,
- VarlenAttnArgsToMessageHubHook)
-from .runner import TrainLoop
-from .optimizer import MuonOptimWrapperConstructor
-
-__all__ = [
- 'EvaluateChatHook', 'DatasetInfoHook', 'ThroughputHook',
- 'VarlenAttnArgsToMessageHubHook', 'DeepSpeedStrategy', 'TrainLoop',
- 'MuonOptimWrapperConstructor'
-]
diff --git a/code/xtuner/engine/__pycache__/__init__.cpython-311.pyc b/code/xtuner/engine/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 52ef8c71f7fb9b310d8008d381aed533a23b8c1b..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/_strategy/__init__.py b/code/xtuner/engine/_strategy/__init__.py
deleted file mode 100644
index bac6095f977fa39655deb1d95c67d2e641e274b4..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/_strategy/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .deepspeed import DeepSpeedStrategy
-
-__all__ = ['DeepSpeedStrategy']
diff --git a/code/xtuner/engine/_strategy/__pycache__/__init__.cpython-311.pyc b/code/xtuner/engine/_strategy/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 8f9390bf13944ce9320202b166e9c15fcad77598..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/_strategy/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/_strategy/__pycache__/deepspeed.cpython-311.pyc b/code/xtuner/engine/_strategy/__pycache__/deepspeed.cpython-311.pyc
deleted file mode 100644
index 5f7890938675c81d2e275cfb539f2c45c120dce8..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/_strategy/__pycache__/deepspeed.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/_strategy/deepspeed.py b/code/xtuner/engine/_strategy/deepspeed.py
deleted file mode 100644
index ee4797149107339e2acc09c360a5d6931ea67a64..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/_strategy/deepspeed.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
-
-from mmengine._strategy import DeepSpeedStrategy as MMEngineDeepSpeedStrategy
-
-from xtuner import DS_CEPH_DIR
-from xtuner.parallel.sequence import init_sequence_parallel
-from xtuner.utils.fileio import patch_fileio
-
-
-
-def count_parameters(model):
- total_params = sum(p.numel() for p in model.parameters())
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-
- print(f"Total parameters: {total_params:,}")
- print(f"Trainable parameters: {trainable_params:,}")
- return total_params, trainable_params
-
-
-class DeepSpeedStrategy(MMEngineDeepSpeedStrategy):
-
- def __init__(self, *args, **kwargs):
- sequence_parallel_size = kwargs.pop('sequence_parallel_size', 1)
- self.sequence_parallel_size = sequence_parallel_size
-
- super().__init__(*args, **kwargs)
-
- from transformers.integrations.deepspeed import HfDeepSpeedConfig
-
- # hf_deepspeed_config has to be saved as an attribute.
- self.hf_deepspeed_config = HfDeepSpeedConfig(self.config)
-
- def _wrap_model(self, model):
- count_parameters(model)
- wrapper = super()._wrap_model(model)
- # hard code for deepspeed zero3
- # When utilizing Zero3, the model isn't allocated to CUDA within the
- # `deepspeed.initialize` process.
- assert hasattr(wrapper.model, 'data_preprocessor')
- wrapper.model.data_preprocessor.cuda()
- return wrapper
-
- def save_checkpoint(self, *args, **kwargs) -> None:
- if DS_CEPH_DIR:
- from os import path as osp
- work_dir_prefix = osp.split(self.work_dir)[0]
-
- filename = kwargs['filename'].replace(work_dir_prefix, DS_CEPH_DIR)
- kwargs['filename'] = filename
- with patch_fileio():
- super().save_checkpoint(*args, **kwargs)
- else:
- super().save_checkpoint(*args, **kwargs)
-
- def load_checkpoint(self, *args, **kwargs) -> None:
- if DS_CEPH_DIR:
-
- with patch_fileio():
- checkpoint = super().load_checkpoint(*args, **kwargs)
- else:
- checkpoint = super().load_checkpoint(*args, **kwargs)
- return checkpoint
-
- def resume(self, *args, **kwargs) -> None:
- if DS_CEPH_DIR:
-
- with patch_fileio():
- checkpoint = super().resume(*args, **kwargs)
- else:
- checkpoint = super().resume(*args, **kwargs)
- return checkpoint
-
- def _setup_distributed( # type: ignore
- self,
- launcher: Optional[str] = None,
- backend: str = 'nccl',
- **kwargs,
- ):
- super()._setup_distributed(launcher, backend, **kwargs)
- init_sequence_parallel(self.sequence_parallel_size)
diff --git a/code/xtuner/engine/hooks/__init__.py b/code/xtuner/engine/hooks/__init__.py
deleted file mode 100644
index a19df0ab6c1b53c2cd1655f720f0febb501689f5..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .dataset_info_hook import DatasetInfoHook
-from .evaluate_chat_hook import EvaluateChatHook
-from .evaluate_chat_hook_resampler import EvaluateChatHook as EvaluateChatHookResampler
-from .hf_checkpoint_hook import HFCheckpointHook
-from .throughput_hook import ThroughputHook
-from .varlen_attn_args_to_messagehub_hook import VarlenAttnArgsToMessageHubHook
-from .learning_rate_hook import PrintInitLRHook
-from .visual_warmup_book import TwoPhaseVisualWarmupHook
-
-__all__ = [
- 'EvaluateChatHook', 'DatasetInfoHook', 'ThroughputHook',
- 'VarlenAttnArgsToMessageHubHook', 'HFCheckpointHook',
- "PrintInitLRHook",
- "EvaluateChatHookResampler",
- "TwoPhaseVisualWarmupHook"
-]
diff --git a/code/xtuner/engine/hooks/__pycache__/__init__.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 2a843d7c5a5eda731c077844fdc4c4d72b6b28f3..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/dataset_info_hook.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/dataset_info_hook.cpython-311.pyc
deleted file mode 100644
index 83cb129f13de3d82b5a5090681a35a04d214e19c..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/dataset_info_hook.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/evaluate_chat_hook.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/evaluate_chat_hook.cpython-311.pyc
deleted file mode 100644
index 44e29065325867eaedcb919019deb3836f081ca8..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/evaluate_chat_hook.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/evaluate_chat_hook_resampler.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/evaluate_chat_hook_resampler.cpython-311.pyc
deleted file mode 100644
index 7c069ff1af881acff25f59929b7aefda039c7d1a..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/evaluate_chat_hook_resampler.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/hf_checkpoint_hook.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/hf_checkpoint_hook.cpython-311.pyc
deleted file mode 100644
index 57f323921d7e0bb8027671b6ea03132c9242c0f4..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/hf_checkpoint_hook.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/learning_rate_hook.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/learning_rate_hook.cpython-311.pyc
deleted file mode 100644
index 0325de85985948e74d76854253f776112123e744..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/learning_rate_hook.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/throughput_hook.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/throughput_hook.cpython-311.pyc
deleted file mode 100644
index c7542797160d6107834220ca834e4aed858c6ee9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/throughput_hook.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/varlen_attn_args_to_messagehub_hook.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/varlen_attn_args_to_messagehub_hook.cpython-311.pyc
deleted file mode 100644
index 8c38f567a1e7bc05ccf2174a094f3f2706f6bcc0..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/varlen_attn_args_to_messagehub_hook.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/__pycache__/visual_warmup_book.cpython-311.pyc b/code/xtuner/engine/hooks/__pycache__/visual_warmup_book.cpython-311.pyc
deleted file mode 100644
index 0dc426ef3fd72a60a42b993f9f29b1ce497b1e77..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/hooks/__pycache__/visual_warmup_book.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/hooks/dataset_info_hook.py b/code/xtuner/engine/hooks/dataset_info_hook.py
deleted file mode 100644
index 84dc9498a4ce0aa2cc8175c9e317e1a35ca13fc9..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/dataset_info_hook.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.hooks import Hook
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
-
-
-def split_list(lst, value):
- res = []
- tmp_res = []
- for i in lst:
- if i == value:
- res.append(tmp_res)
- tmp_res = []
- else:
- tmp_res.append(i)
- res.append(tmp_res)
- return res
-
-
-class DatasetInfoHook(Hook):
-
- def __init__(self, tokenizer, is_intern_repo_dataset=False):
- self.tokenizer = BUILDER.build(tokenizer)
- self.is_intern_repo_dataset = is_intern_repo_dataset
-
- def log(self, runner, dataset, mode='train'):
-
- def _log(input_ids, log_prefix=''):
- if self.is_intern_repo_dataset:
- input_ids = [abs(x) for x in input_ids]
- # Try to split list to be compatible with IMAGE token
- input_ids = split_list(input_ids, IMAGE_TOKEN_INDEX)
- text = log_prefix
- for idx, ids in enumerate(input_ids):
- text += self.tokenizer.decode(ids)
- if idx != len(input_ids) - 1:
- text += DEFAULT_IMAGE_TOKEN
- runner.logger.info(text)
-
- runner.logger.info(f'Num {mode} samples {len(dataset)}')
- runner.logger.info(f'{mode} example:')
- if 'chosen_ids' in dataset[0]:
- _log(dataset[0]['chosen_ids'], log_prefix='chosen: ')
- _log(dataset[0]['rejected_ids'], log_prefix='rejected: ')
- else:
- _log(dataset[0]['input_ids'])
-
- def before_train(self, runner) -> None:
- do_train = runner.train_loop is not None
- do_eval = runner.val_loop is not None
- if do_train:
- train_dataset = runner.train_dataloader.dataset
- self.log(runner, train_dataset, mode='train')
- if do_eval:
- eval_dataset = runner.val_dataloader.dataset
- self.log(runner, eval_dataset, mode='eval')
-
- def before_val(self, runner) -> None:
- eval_dataset = runner.val_dataloader.dataset
- self.log(runner, eval_dataset, mode='eval')
-
- def before_test(self, runner) -> None:
- test_dataset = runner.test_dataloader.dataset
- self.log(runner, test_dataset, mode='test')
diff --git a/code/xtuner/engine/hooks/evaluate_chat_hook.py b/code/xtuner/engine/hooks/evaluate_chat_hook.py
deleted file mode 100644
index 8c573abf1e842cd5a349359bebb715500e5aef5b..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/evaluate_chat_hook.py
+++ /dev/null
@@ -1,371 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import warnings
-
-import torch
-from mmengine.dist import master_only
-from mmengine.hooks import Hook
-from mmengine.model import is_model_wrapper
-from mmengine.utils import mkdir_or_exist
-from mmengine.utils.misc import get_object_from_string
-from transformers import GenerationConfig, StoppingCriteriaList
-
-from xtuner.dataset.utils import load_image
-from xtuner.model.utils import prepare_inputs_labels_for_multimodal
-from xtuner.registry import BUILDER
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-
-
-class EvaluateChatHook(Hook):
- """
- Eval hook updated for LLaVAModel:
- - Accepts features (and optional coords) from `load_image`.
- - Applies optional SparsePatchMerging (token_merge) using coords.
- - Projects features with `model.projector`.
- - Optionally compresses with Perceiver resampler (if enabled), with
- tile-position PE injection via `coords_to_pos`.
- - Prepares multimodal inputs and runs generation on `model.llm`.
-
- Coordinates (if present) must be [L,2] or [B,L,2] with last dim=2.
- """
-
- priority = 'LOW'
-
- def __init__(self,
- tokenizer,
- evaluation_inputs,
- evaluation_images=None,
- image_processor=None,
- system='',
- prompt_template=None,
- every_n_iters=None,
- max_new_tokens=600,
- stop_word=None,
- stop_words=[],
- generation_kwargs={}):
- self.evaluation_inputs = evaluation_inputs
- if isinstance(self.evaluation_inputs, str):
- self.evaluation_inputs = [self.evaluation_inputs]
-
- # Accept paths to images/features; normalize length to inputs length
- self.evaluation_images = evaluation_images
- if isinstance(self.evaluation_images, str):
- self.evaluation_images = [self.evaluation_images]
- if self.evaluation_images is not None:
- assert len(self.evaluation_images) in [1, len(self.evaluation_inputs)]
- if len(self.evaluation_images) == 1:
- self.evaluation_images = [self.evaluation_images[0]] * len(self.evaluation_inputs)
-
- # Load features and optional coords
- self.eval_feats = []
- self.eval_coords = []
- for img in self.evaluation_images:
- loaded = load_image(img)
- if isinstance(loaded, tuple):
- feats, coords = loaded
- else:
- feats, coords = loaded, None
- self.eval_feats.append(feats)
- self.eval_coords.append(coords)
- else:
- self.eval_feats, self.eval_coords = None, None
-
- # Prompt templating and stopwords
- if prompt_template is None:
- instruction = '{input}'
- else:
- if isinstance(prompt_template, str): # for resume
- prompt_template = get_object_from_string(prompt_template)
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- if system != '':
- system = prompt_template.get('SYSTEM', '{system}\n').format(system=system)
- stop_words += prompt_template.get('STOP_WORDS', [])
- if stop_word is not None:
- warnings.warn(
- ('The `stop_word` argument is deprecated and will be removed '
- 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
- stop_words.append(stop_word)
- self.instruction = instruction
- self.system = system
- self.every_n_iters = every_n_iters
- self.max_new_tokens = max_new_tokens
- self.tokenizer = BUILDER.build(tokenizer)
- if image_processor is not None:
- self.image_processor = BUILDER.build(image_processor)
-
- # default generation config
- default_generation_kwargs = dict(
- max_new_tokens=max_new_tokens,
- do_sample=True,
- temperature=0.1,
- top_p=0.75,
- top_k=40,
- eos_token_id=self.tokenizer.eos_token_id,
- pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
- )
- default_generation_kwargs.update(generation_kwargs)
- self.gen_config = GenerationConfig(**default_generation_kwargs)
-
- self.stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- self.stop_criteria.append(StopWordStoppingCriteria(self.tokenizer, word))
-
- self.is_first_run = True
-
- @master_only
- def _save_eval_output(self, runner, eval_outputs):
- save_path = os.path.join(runner.log_dir, 'vis_data', f'eval_outputs_iter_{runner.iter}.txt')
- mkdir_or_exist(os.path.dirname(save_path))
- with open(save_path, 'w', encoding='utf-8') as f:
- for i, output in enumerate(eval_outputs):
- f.write(f'Eval output {i + 1}:\n{output}\n\n')
-
- @torch.no_grad()
- def _eval_images(self,
- runner,
- model,
- device,
- max_new_tokens=None,
- save_eval_output=False):
- """
- New image/feature eval path that mirrors LLaVAModel.forward():
- - move feats/coords to device
- - optional token_merge with coords
- - projector → (optional) PE injection → perceiver
- - prepare_inputs_labels_for_multimodal and generate on model.llm
- """
- if save_eval_output:
- eval_outputs = []
-
- llm = model.llm # convenience
- dtype = llm.dtype
-
- for sample_feats, sample_coords, sample_input in zip(self.eval_feats, self.eval_coords, self.evaluation_inputs):
- # Build chat prompt with a single token
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (self.system + self.instruction).format(input=sample_input, round=1, **runner.cfg)
-
- # Manually place IMAGE_TOKEN_INDEX into the input_ids
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = self.tokenizer.encode(chunk)
- else:
- cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) # [1, T]
-
- # ---- Move features / coords to device and validate shapes ----
- # feats: numpy (1, N, D) or torch [1, N, D]
- feats = torch.from_numpy(sample_feats).to(device) if not torch.is_tensor(sample_feats) else sample_feats.to(device)
- feats = feats.to(dtype) # LLaVAModel expects llm dtype
-
- coords_t = None
- if sample_coords is not None:
- coords_t = torch.from_numpy(sample_coords).to(device) if not torch.is_tensor(sample_coords) else sample_coords.to(device)
- # Accept [L,2] or [B,L,2]; enforce last dim = 2
- if coords_t.dim() == 2:
- coords_rc = coords_t # [L,2]
- elif coords_t.dim() == 3:
- Bx = feats.size(0)
- if coords_t.size(0) != Bx:
- raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}")
- if Bx == 1:
- coords_rc = coords_t[0]
- else:
- if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)):
- raise NotImplementedError(
- "Per-example coords (varying across batch) are not supported by the current "
- "patch-merging/layout path. Use batch size 1 or share coords across the batch."
- )
- coords_rc = coords_t[0]
- else:
- raise ValueError("coords must have shape [L,2] or [B,L,2].")
- if coords_rc.size(-1) != 2:
- raise ValueError("coords last dimension must be 2.")
- else:
- # LLaVAModel forward requires coords when visual tokens are present.
- raise RuntimeError("WSI evaluation requires coordinates for token merge and PE/Perceiver.")
-
- # ---- Optional token-merge using row/col coords (keeps batch=1) ----
- # Expect model._coords_to_rowcol and model.token_merge when enabled
- image_tokens = feats # start from raw patch features [1, N, D]
- if getattr(model, "enable_token_merge", False) and hasattr(model, "token_merge"):
- if not hasattr(model, "_coords_to_rowcol"):
- raise AttributeError("Model is missing `_coords_to_rowcol` required for token_merge.")
- image_tokens, coords_after, _ = model.token_merge(
- x=image_tokens,
- coords_rc=model._coords_to_rowcol(coords_rc),
- padmask=torch.zeros([image_tokens.size(0), image_tokens.size(1)], device=image_tokens.device).bool()
- )
- # keep coords for PE; coords_after is typically aligned to merged tokens
- coords_rc = coords_after
-
- # ---- Project into LLM hidden space ----
- pixel_values = model.projector(image_tokens.to(dtype)) # [1, S', H_llm_in]
-
- # ---- Optional Perceiver resampler with position encodings ----
- if getattr(model, "use_perceiver_resampler", False) and hasattr(model, "perceiver"):
- # text embeddings from LLM’s token embed; clamp(min=0) to avoid -1 for special tokens
- text_embeds = llm.get_input_embeddings()(input_ids.clamp(min=0)).to(dtype).detach()
-
- # coordinates → 1D positions; inject PE as in model.forward
- if not hasattr(model, "coords_to_pos"):
- raise AttributeError("Model is missing `coords_to_pos` required for PE injection.")
- pos = model.coords_to_pos(coords_rc, getattr(model, "tile_size", 224)) # [N, L]
- # Ensure pos has batch dimension
- if pos.dim() == 1:
- pos = pos.unsqueeze(0)
-
- # pos_embed: [1, num_patches, H]; gather then scale/drop
- if not hasattr(model, "pos_embed") or not hasattr(model, "pe_gate") or not hasattr(model, "pe_drop"):
- raise AttributeError("Model PE buffers (pos_embed/pe_gate/pe_drop) not found.")
- # Gather PEs for current positions and inject
- pe = model.pos_embed[:, pos, :].squeeze(0) # [B, L, H]
- pixel_values = pixel_values + model.pe_drop(pe * model.pe_gate)
-
- # Run perceiver to compress visual tokens conditioned on text (attention_mask optional)
- pixel_values = model.perceiver(
- text_embeddings=text_embeds,
- attention_mask=None,
- visual_tokens=pixel_values,
- )
-
- runner.logger.info(
- f'evaluate feats: {feats.shape}, coords: {coords_t.shape if sample_coords is not None else None}, '
- f'pixel_values(after proj{" + perceiver" if getattr(model,"use_perceiver_resampler",False) else ""}): {pixel_values.shape}'
- )
-
- # ---- Pack for generation and decode ----
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=llm, input_ids=input_ids, pixel_values=pixel_values)
-
- # IMPORTANT: drive generate on the LLM (model.llm), not the wrapper
- generation_output = llm.generate(
- **mm_inputs,
- max_new_tokens=max_new_tokens,
- generation_config=self.gen_config,
- bos_token_id=self.tokenizer.bos_token_id,
- stopping_criteria=self.stop_criteria
- )
- # generation_output: [1, T_full] — decode only the full string (like before)
- text = self.tokenizer.decode(generation_output[0])
- runner.logger.info(f'Sample output:\n{inputs + text}\n')
-
- if save_eval_output:
- eval_outputs.append(f'{inputs + text}\n')
-
- if save_eval_output:
- self._save_eval_output(runner, eval_outputs)
-
- @torch.no_grad()
- def _eval_language(self,
- runner,
- model,
- device,
- max_new_tokens=None,
- save_eval_output=False):
- if save_eval_output:
- eval_outputs = []
-
- llm = model.llm
- for sample_input in self.evaluation_inputs:
- inputs = (self.system + self.instruction).format(input=sample_input, round=1, **runner.cfg)
- input_ids = self.tokenizer.encode(inputs, return_tensors='pt').to(device)
- generation_output = llm.generate(
- input_ids=input_ids,
- max_new_tokens=max_new_tokens,
- generation_config=self.gen_config,
- stopping_criteria=self.stop_criteria)
- text = self.tokenizer.decode(generation_output[0])
- runner.logger.info(f'Sample output:\n{text}\n')
- if save_eval_output:
- eval_outputs.append(f'{text}\n')
-
- if save_eval_output:
- self._save_eval_output(runner, eval_outputs)
-
- @torch.no_grad()
- def _generate_samples(self,
- runner,
- max_new_tokens=None,
- save_eval_output=False):
- if max_new_tokens is None:
- max_new_tokens = self.max_new_tokens
- model = runner.model
- if is_model_wrapper(model):
- model = model.module
-
- device = next(iter(model.parameters())).device
-
- if self.is_first_run:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to device
- model.to(device)
- self.is_first_run = False
-
- is_checkpointing = model.llm.is_gradient_checkpointing
- use_cache = model.llm.config.use_cache
-
- # Cast to inference mode
- model.activation_checkpointing_disable()
- model.llm.config.use_cache = True
- model.eval()
- if self.evaluation_images is not None:
- self._eval_images(runner, model, device, max_new_tokens, save_eval_output)
- else:
- self._eval_language(runner, model, device, max_new_tokens, save_eval_output)
-
- # Cast back to training mode
- if is_checkpointing:
- model.activation_checkpointing_enable()
- model.llm.config.use_cache = use_cache
- model.train()
-
- def before_train(self, runner):
- runner.logger.info('before_train in EvaluateChatHook.')
- self._generate_samples(runner, max_new_tokens=50)
-
- def _is_save_checkpoint(self, runner):
- hooks = runner.hooks
- checkpoint_hook = None
- for hook in hooks:
- if type(hook).__name__ == 'CheckpointHook':
- checkpoint_hook = hook
- break
- if checkpoint_hook is None or checkpoint_hook.by_epoch:
- return False
-
- if checkpoint_hook.every_n_train_iters(
- runner, checkpoint_hook.interval, checkpoint_hook.save_begin) or \
- (checkpoint_hook.save_last and checkpoint_hook.is_last_train_iter(runner)):
- return True
- return False
-
- def after_train_iter(self, runner, batch_idx: int, data_batch=None, outputs=None) -> None:
- if self.every_n_iters is None:
- return
-
- save_eval_output = self._is_save_checkpoint(runner)
- do_chat = (save_eval_output or self.every_n_train_iters(runner, self.every_n_iters))
- if not do_chat:
- return
-
- runner.logger.info('after_train_iter in EvaluateChatHook.')
- self._generate_samples(runner, save_eval_output=save_eval_output)
-
- def after_train(self, runner):
- runner.logger.info('after_train in EvaluateChatHook.')
- self._generate_samples(runner)
-
- def after_val(self, runner) -> None:
- if self.every_n_iters is not None:
- return
- runner.logger.info('after_val in EvaluateChatHook.')
- self._generate_samples(runner)
\ No newline at end of file
diff --git a/code/xtuner/engine/hooks/evaluate_chat_hook_raw.py b/code/xtuner/engine/hooks/evaluate_chat_hook_raw.py
deleted file mode 100644
index 05d508e4c8f232a9299c1d1b7f69cfbc18262dbc..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/evaluate_chat_hook_raw.py
+++ /dev/null
@@ -1,281 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import warnings
-
-import torch
-from mmengine.dist import master_only
-from mmengine.hooks import Hook
-from mmengine.model import is_model_wrapper
-from mmengine.utils import mkdir_or_exist
-from mmengine.utils.misc import get_object_from_string
-from transformers import GenerationConfig, StoppingCriteriaList
-
-from xtuner.dataset.utils import expand2square, load_image
-from xtuner.model.utils import prepare_inputs_labels_for_multimodal
-from xtuner.registry import BUILDER
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-
-
-class EvaluateChatHook(Hook):
-
- priority = 'LOW'
-
- def __init__(self,
- tokenizer,
- evaluation_inputs,
- evaluation_images=None,
- image_processor=None,
- system='',
- prompt_template=None,
- every_n_iters=None,
- max_new_tokens=600,
- stop_word=None,
- stop_words=[],
- generation_kwargs={}):
- self.evaluation_inputs = evaluation_inputs
- if isinstance(self.evaluation_inputs, str):
- self.evaluation_inputs = [self.evaluation_inputs]
- self.evaluation_images = evaluation_images
- if isinstance(self.evaluation_images, str):
- self.evaluation_images = [self.evaluation_images]
- if self.evaluation_images is not None:
- assert len(
- self.evaluation_images) in [1, len(self.evaluation_inputs)]
- if len(self.evaluation_images) == 1:
- self.evaluation_images = [self.evaluation_images[0]] * len(
- self.evaluation_inputs)
- self.evaluation_images = [
- load_image(img) for img in self.evaluation_images
- ]
- if prompt_template is None:
- instruction = '{input}'
- else:
- if isinstance(prompt_template, str): # for resume
- prompt_template = get_object_from_string(prompt_template)
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- if system != '':
- system = prompt_template.get(
- 'SYSTEM', '{system}\n').format(system=system)
- stop_words += prompt_template.get('STOP_WORDS', [])
- if stop_word is not None:
- # TODO: deprecation, v0.3.0
- warnings.warn(
- ('The `stop_word` argument is deprecated and will be removed '
- 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
- stop_words.append(stop_word)
- self.instruction = instruction
- self.system = system
- self.every_n_iters = every_n_iters
- self.max_new_tokens = max_new_tokens
- self.tokenizer = BUILDER.build(tokenizer)
- if image_processor is not None:
- self.image_processor = BUILDER.build(image_processor)
- self.stop_criteria = StoppingCriteriaList()
-
- # default generation config
- default_generation_kwargs = dict(
- max_new_tokens=max_new_tokens,
- do_sample=True,
- temperature=0.1,
- top_p=0.75,
- top_k=40,
- eos_token_id=self.tokenizer.eos_token_id,
- pad_token_id=self.tokenizer.pad_token_id
- if self.tokenizer.pad_token_id is not None else
- self.tokenizer.eos_token_id)
- default_generation_kwargs.update(generation_kwargs)
- self.gen_config = GenerationConfig(**default_generation_kwargs)
-
- self.stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- self.stop_criteria.append(
- StopWordStoppingCriteria(self.tokenizer, word))
-
- self.is_first_run = True
-
- @master_only
- def _save_eval_output(self, runner, eval_outputs):
- save_path = os.path.join(runner.log_dir, 'vis_data',
- f'eval_outputs_iter_{runner.iter}.txt')
- mkdir_or_exist(os.path.dirname(save_path))
- with open(save_path, 'w', encoding='utf-8') as f:
- for i, output in enumerate(eval_outputs):
- f.write(f'Eval output {i + 1}:\n{output}\n\n')
-
- def _eval_images(self,
- runner,
- model,
- device,
- max_new_tokens=None,
- save_eval_output=False):
- if save_eval_output:
- eval_outputs = []
-
- for sample_image, sample_input in zip(self.evaluation_images,
- self.evaluation_inputs):
- image = expand2square(
- sample_image,
- tuple(int(x * 255) for x in self.image_processor.image_mean))
- image = self.image_processor.preprocess(
- image, return_tensors='pt')['pixel_values'][0]
- image = image.to(device)
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (self.system + self.instruction).format(
- input=sample_input, round=1, **runner.cfg)
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = self.tokenizer.encode(chunk)
- else:
- cur_encode = self.tokenizer.encode(
- chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids).to(device)
- visual_outputs = model.visual_encoder(
- image.unsqueeze(0).to(model.visual_encoder.dtype),
- output_hidden_states=True)
- pixel_values = model.projector(
- visual_outputs.hidden_states[model.visual_select_layer][:, 1:])
-
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=model.llm,
- input_ids=input_ids.unsqueeze(0),
- pixel_values=pixel_values)
-
- generation_output = model.generate(
- **mm_inputs,
- max_new_tokens=max_new_tokens,
- generation_config=self.gen_config,
- bos_token_id=self.tokenizer.bos_token_id,
- stopping_criteria=self.stop_criteria)
- generation_output = self.tokenizer.decode(generation_output[0])
- runner.logger.info(f'Sample output:\n'
- f'{inputs + generation_output}\n')
- if save_eval_output:
- eval_outputs.append(f'{inputs + generation_output}\n')
-
- if save_eval_output:
- self._save_eval_output(runner, eval_outputs)
-
- def _eval_language(self,
- runner,
- model,
- device,
- max_new_tokens=None,
- save_eval_output=False):
- if save_eval_output:
- eval_outputs = []
-
- for sample_input in self.evaluation_inputs:
- inputs = (self.system + self.instruction).format(
- input=sample_input, round=1, **runner.cfg)
- input_ids = self.tokenizer.encode(inputs, return_tensors='pt')
- input_ids = input_ids.to(device)
- generation_output = model.generate(
- input_ids=input_ids,
- max_new_tokens=max_new_tokens,
- generation_config=self.gen_config,
- stopping_criteria=self.stop_criteria)
- generation_output = self.tokenizer.decode(generation_output[0])
- runner.logger.info(f'Sample output:\n{generation_output}\n')
- if save_eval_output:
- eval_outputs.append(f'{generation_output}\n')
-
- if save_eval_output:
- self._save_eval_output(runner, eval_outputs)
-
- def _generate_samples(self,
- runner,
- max_new_tokens=None,
- save_eval_output=False):
- if max_new_tokens is None:
- max_new_tokens = self.max_new_tokens
- model = runner.model
- if is_model_wrapper(model):
- model = model.module
-
- device = next(iter(model.parameters())).device
-
- if self.is_first_run:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- model.to(device)
- self.is_first_run = False
-
- is_checkpointing = model.llm.is_gradient_checkpointing
- use_cache = model.llm.config.use_cache
-
- # Cast to inference mode
- model.activation_checkpointing_disable()
- model.llm.config.use_cache = True
- model.eval()
- if self.evaluation_images is not None:
- self._eval_images(runner, model, device, max_new_tokens,
- save_eval_output)
- else:
- self._eval_language(runner, model, device, max_new_tokens,
- save_eval_output)
-
- # Cast to training mode
- if is_checkpointing:
- model.activation_checkpointing_enable()
- model.llm.config.use_cache = use_cache
- model.train()
-
- def before_train(self, runner):
- runner.logger.info('before_train in EvaluateChatHook.')
- self._generate_samples(runner, max_new_tokens=50)
-
- def _is_save_checkpoint(self, runner):
- hooks = runner.hooks
- checkpoint_hook = None
- for hook in hooks:
- if type(hook).__name__ == 'CheckpointHook':
- checkpoint_hook = hook
- break
- if checkpoint_hook is None or checkpoint_hook.by_epoch:
- return False
-
- if checkpoint_hook.every_n_train_iters(
- runner, checkpoint_hook.interval, checkpoint_hook.save_begin) or \
- (checkpoint_hook.save_last and
- checkpoint_hook.is_last_train_iter(runner)):
- return True
-
- return False
-
- def after_train_iter(self,
- runner,
- batch_idx: int,
- data_batch=None,
- outputs=None) -> None:
- if self.every_n_iters is None:
- return
-
- save_eval_output = self._is_save_checkpoint(runner)
-
- do_chat = (
- save_eval_output
- or self.every_n_train_iters(runner, self.every_n_iters))
- if not do_chat:
- return
-
- runner.logger.info('after_train_iter in EvaluateChatHook.')
- self._generate_samples(runner, save_eval_output=save_eval_output)
-
- def after_train(self, runner):
- runner.logger.info('after_train in EvaluateChatHook.')
- self._generate_samples(runner)
-
- def after_val(self, runner) -> None:
- if self.every_n_iters is not None:
- return
- runner.logger.info('after_val in EvaluateChatHook.')
- self._generate_samples(runner)
diff --git a/code/xtuner/engine/hooks/evaluate_chat_hook_resampler.py b/code/xtuner/engine/hooks/evaluate_chat_hook_resampler.py
deleted file mode 100644
index a0b17dc02d71789ac15de136e8e8317fa04aa103..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/evaluate_chat_hook_resampler.py
+++ /dev/null
@@ -1,337 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import warnings
-
-import torch
-from mmengine.dist import master_only
-from mmengine.hooks import Hook
-from mmengine.model import is_model_wrapper
-from mmengine.utils import mkdir_or_exist
-from mmengine.utils.misc import get_object_from_string
-from transformers import GenerationConfig, StoppingCriteriaList
-
-from xtuner.dataset.utils import load_image
-from xtuner.model.utils import prepare_inputs_labels_for_multimodal
-from xtuner.registry import BUILDER
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-
-
-class EvaluateChatHook(Hook):
- """
- Eval hook updated for LLaVAModel with the new Resampler:
- - Accepts features (and optional coords) from `load_image`.
- - Applies optional SparsePatchMerging (token_merge) using coords.
- - Projects features with `model.projector`.
- - Optionally compresses with the new Resampler, which takes projected
- features and their coordinates (coords_rc) as input.
- - Prepares multimodal inputs and runs generation on `model.llm`.
-
- Coordinates (if present) must be [L,2] or [B,L,2] with last dim=2.
- """
-
- priority = 'LOW'
-
- def __init__(self,
- tokenizer,
- evaluation_inputs,
- evaluation_images=None,
- image_processor=None,
- system='',
- prompt_template=None,
- every_n_iters=None,
- max_new_tokens=600,
- stop_word=None,
- stop_words=[],
- generation_kwargs={}):
- self.evaluation_inputs = evaluation_inputs
- if isinstance(self.evaluation_inputs, str):
- self.evaluation_inputs = [self.evaluation_inputs]
-
- # Accept paths to images/features; normalize length to inputs length
- self.evaluation_images = evaluation_images
- if isinstance(self.evaluation_images, str):
- self.evaluation_images = [self.evaluation_images]
- if self.evaluation_images is not None:
- assert len(self.evaluation_images) in [1, len(self.evaluation_inputs)]
- if len(self.evaluation_images) == 1:
- self.evaluation_images = [self.evaluation_images[0]] * len(self.evaluation_inputs)
-
- # Load features and optional coords
- self.eval_feats = []
- self.eval_coords = []
- for img in self.evaluation_images:
- loaded = load_image(img)
- if isinstance(loaded, tuple):
- feats, coords = loaded
- else:
- feats, coords = loaded, None
- self.eval_feats.append(feats)
- self.eval_coords.append(coords)
- else:
- self.eval_feats, self.eval_coords = None, None
-
- # Prompt templating and stopwords
- if prompt_template is None:
- instruction = '{input}'
- else:
- if isinstance(prompt_template, str): # for resume
- prompt_template = get_object_from_string(prompt_template)
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- if system != '':
- system = prompt_template.get('SYSTEM', '{system}\n').format(system=system)
- stop_words += prompt_template.get('STOP_WORDS', [])
- if stop_word is not None:
- warnings.warn(
- ('The `stop_word` argument is deprecated and will be removed '
- 'in v0.3.0, use `stop_words` instead.'), DeprecationWarning)
- stop_words.append(stop_word)
- self.instruction = instruction
- self.system = system
- self.every_n_iters = every_n_iters
- self.max_new_tokens = max_new_tokens
- self.tokenizer = BUILDER.build(tokenizer)
- if image_processor is not None:
- self.image_processor = BUILDER.build(image_processor)
-
- # default generation config
- default_generation_kwargs = dict(
- max_new_tokens=max_new_tokens,
- do_sample=True,
- temperature=0.1,
- top_p=0.75,
- top_k=40,
- eos_token_id=self.tokenizer.eos_token_id,
- pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
- )
- default_generation_kwargs.update(generation_kwargs)
- self.gen_config = GenerationConfig(**default_generation_kwargs)
-
- self.stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- self.stop_criteria.append(StopWordStoppingCriteria(self.tokenizer, word))
-
- self.is_first_run = True
-
- @master_only
- def _save_eval_output(self, runner, eval_outputs):
- save_path = os.path.join(runner.log_dir, 'vis_data', f'eval_outputs_iter_{runner.iter}.txt')
- mkdir_or_exist(os.path.dirname(save_path))
- with open(save_path, 'w', encoding='utf-8') as f:
- for i, output in enumerate(eval_outputs):
- f.write(f'Eval output {i + 1}:\n{output}\n\n')
-
- @torch.no_grad()
- def _eval_images(self,
- runner,
- model,
- device,
- max_new_tokens=None,
- save_eval_output=False):
- """
- New image/feature eval path that mirrors LLaVAModel.forward():
- - move feats/coords to device
- - optional token_merge with coords
- - projector -> (optional) resampler
- - prepare_inputs_labels_for_multimodal and generate on model.llm
- """
- if save_eval_output:
- eval_outputs = []
-
- llm = model.llm # convenience
- dtype = llm.dtype
-
- for sample_feats, sample_coords, sample_input in zip(self.eval_feats, self.eval_coords, self.evaluation_inputs):
- # Build chat prompt with a single token
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (self.system + self.instruction).format(input=sample_input, round=1, **runner.cfg)
-
- # Manually place IMAGE_TOKEN_INDEX into the input_ids
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = self.tokenizer.encode(chunk)
- else:
- cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) # [1, T]
-
- # ---- Move features / coords to device and validate shapes ----
- feats = torch.from_numpy(sample_feats).to(device) if not torch.is_tensor(sample_feats) else sample_feats.to(device)
- feats = feats.to(dtype)
-
- coords_rc = None # Initialize coords_rc
- if sample_coords is not None:
- coords_t = torch.from_numpy(sample_coords).to(device) if not torch.is_tensor(sample_coords) else sample_coords.to(device)
- if coords_t.dim() == 2:
- coords_rc = coords_t
- elif coords_t.dim() == 3:
- if coords_t.size(0) != 1:
- raise NotImplementedError("Evaluation hook currently only supports a batch size of 1 for visual features.")
- coords_rc = coords_t[0]
- else:
- raise ValueError("coords must have shape [L,2] or [B,L,2].")
- if coords_rc.size(-1) != 2:
- raise ValueError("coords last dimension must be 2.")
- else:
- raise RuntimeError("Visual evaluation requires coordinates for the Resampler.")
-
- # ---- Optional token-merge using row/col coords ----
- image_tokens = feats
- if getattr(model, "enable_token_merge", False) and hasattr(model, "token_merge"):
- if not hasattr(model, "_coords_to_rowcol"):
- raise AttributeError("Model is missing `_coords_to_rowcol` required for token_merge.")
- print(f'Before token_merge, image_tokens: {image_tokens.shape}, coords_rc: {coords_rc.shape if coords_rc is not None else None}')
- image_tokens, coords_rc, _ = model.token_merge(
- x=image_tokens,
- coords_rc=model._coords_to_rowcol(coords_rc),
- padmask=torch.zeros([image_tokens.size(0), image_tokens.size(1)], device=image_tokens.device).bool()
- )
-
- # ---- Project into LLM hidden space ----
- pixel_values = model.projector(image_tokens.to(dtype))
-
- # ---- Optional Resampler ----
- if getattr(model, "use_resampler", False) and hasattr(model, "resampler"):
- if coords_rc is None:
- raise RuntimeError("The Resampler requires coordinates (coords_rc) but they were not found.")
- if coords_rc.dtype != torch.long:
- coords_rc = coords_rc.long()
- pixel_values = model.resampler(pixel_values, coords_rc)
-
- log_msg_extra = " + resampler" if getattr(model, "use_resampler", False) else ""
- runner.logger.info(
- f'evaluate feats: {feats.shape}, coords: {coords_t.shape if sample_coords is not None else None}, '
- f'pixel_values(after proj{log_msg_extra}): {pixel_values.shape}'
- )
-
- # ---- Pack for generation and decode ----
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=llm, input_ids=input_ids, pixel_values=pixel_values)
-
- generation_output = llm.generate(
- **mm_inputs,
- max_new_tokens=max_new_tokens,
- generation_config=self.gen_config,
- bos_token_id=self.tokenizer.bos_token_id,
- stopping_criteria=self.stop_criteria
- )
- text = self.tokenizer.decode(generation_output[0])
- runner.logger.info(f'Sample output:\n{inputs + text}\n')
-
- if save_eval_output:
- eval_outputs.append(f'{inputs + text}\n')
-
- if save_eval_output:
- self._save_eval_output(runner, eval_outputs)
-
- @torch.no_grad()
- def _eval_language(self,
- runner,
- model,
- device,
- max_new_tokens=None,
- save_eval_output=False):
- if save_eval_output:
- eval_outputs = []
-
- llm = model.llm
- for sample_input in self.evaluation_inputs:
- inputs = (self.system + self.instruction).format(input=sample_input, round=1, **runner.cfg)
- input_ids = self.tokenizer.encode(inputs, return_tensors='pt').to(device)
- generation_output = llm.generate(
- input_ids=input_ids,
- max_new_tokens=max_new_tokens,
- generation_config=self.gen_config,
- stopping_criteria=self.stop_criteria)
- text = self.tokenizer.decode(generation_output[0])
- runner.logger.info(f'Sample output:\n{text}\n')
- if save_eval_output:
- eval_outputs.append(f'{text}\n')
-
- if save_eval_output:
- self._save_eval_output(runner, eval_outputs)
-
- @torch.no_grad()
- def _generate_samples(self,
- runner,
- max_new_tokens=None,
- save_eval_output=False):
- if max_new_tokens is None:
- max_new_tokens = self.max_new_tokens
- model = runner.model
- if is_model_wrapper(model):
- model = model.module
-
- device = next(iter(model.parameters())).device
-
- if self.is_first_run:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to device
- model.to(device)
- self.is_first_run = False
-
- is_checkpointing = model.llm.is_gradient_checkpointing
- use_cache = model.llm.config.use_cache
-
- # Cast to inference mode
- model.activation_checkpointing_disable()
- model.llm.config.use_cache = True
- model.eval()
- if self.evaluation_images is not None:
- self._eval_images(runner, model, device, max_new_tokens, save_eval_output)
- else:
- self._eval_language(runner, model, device, max_new_tokens, save_eval_output)
-
- # Cast back to training mode
- if is_checkpointing:
- model.activation_checkpointing_enable()
- model.llm.config.use_cache = use_cache
- model.train()
-
- def before_train(self, runner):
- runner.logger.info('before_train in EvaluateChatHook.')
- self._generate_samples(runner, max_new_tokens=50)
-
- def _is_save_checkpoint(self, runner):
- hooks = runner.hooks
- checkpoint_hook = None
- for hook in hooks:
- if type(hook).__name__ == 'CheckpointHook':
- checkpoint_hook = hook
- break
- if checkpoint_hook is None or checkpoint_hook.by_epoch:
- return False
-
- if checkpoint_hook.every_n_train_iters(
- runner, checkpoint_hook.interval, checkpoint_hook.save_begin) or \
- (checkpoint_hook.save_last and checkpoint_hook.is_last_train_iter(runner)):
- return True
- return False
-
- def after_train_iter(self, runner, batch_idx: int, data_batch=None, outputs=None) -> None:
- if self.every_n_iters is None:
- return
-
- save_eval_output = self._is_save_checkpoint(runner)
- do_chat = (save_eval_output or self.every_n_train_iters(runner, self.every_n_iters))
- if not do_chat:
- return
-
- runner.logger.info('after_train_iter in EvaluateChatHook.')
- self._generate_samples(runner, save_eval_output=save_eval_output)
-
- def after_train(self, runner):
- runner.logger.info('after_train in EvaluateChatHook.')
- self._generate_samples(runner)
-
- def after_val(self, runner) -> None:
- if self.every_n_iters is not None:
- return
- runner.logger.info('after_val in EvaluateChatHook.')
- self._generate_samples(runner)
\ No newline at end of file
diff --git a/code/xtuner/engine/hooks/hf_checkpoint_hook.py b/code/xtuner/engine/hooks/hf_checkpoint_hook.py
deleted file mode 100644
index 142af4cdbc27f34a0e4def644a742258542c2db0..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/hf_checkpoint_hook.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-from pathlib import Path
-from typing import Optional, Union
-
-import torch.distributed as dist
-from mmengine import print_log
-from mmengine._strategy import DeepSpeedStrategy
-from mmengine.hooks import Hook
-from mmengine.model import is_model_wrapper
-from mmengine.runner import FlexibleRunner
-
-from xtuner.registry import BUILDER
-from xtuner.utils import get_origin_state_dict
-
-DATA_BATCH = Optional[Union[dict, tuple, list]]
-
-
-class HFCheckpointHook(Hook):
-
- priority = 95 # lower than CheckpointHook in MMEngine
-
- def __init__(self, out_dir: Optional[Union[str, Path]] = None) -> None:
- self.out_dir = out_dir
-
- @staticmethod
- def _use_shard_moe(llm):
- config = llm.config
- moe_implementation = getattr(config, 'moe_implementation', 'origin')
- return moe_implementation == 'shard'
-
- def after_run(self, runner) -> None:
- assert isinstance(runner,
- FlexibleRunner), 'Runner should be `FlexibleRunner`'
- assert isinstance(
- runner.strategy,
- DeepSpeedStrategy), 'Strategy should be `DeepSpeedStrategy`'
-
- if self.out_dir is None:
- self.out_dir = osp.join(runner.work_dir, 'hf_model')
-
- wrapped_model = runner.strategy.model
- if wrapped_model.zero_optimization_partition_weights():
- assert wrapped_model.zero_gather_16bit_weights_on_model_save(), \
- ('Please set `gather_16bit_weights_on_model_save=True` '
- 'in your DeepSpeed config.')
- state_dict = wrapped_model._zero3_consolidated_16bit_state_dict()
- else:
- state_dict = wrapped_model.module_state_dict(
- exclude_frozen_parameters=runner.strategy.
- exclude_frozen_parameters)
-
- model = runner.model
- if is_model_wrapper(model):
- model = model.module
- llm = model.llm
- if (not dist.is_initialized()) or dist.get_rank() == 0:
- # keys in state_dict are prefixed with 'llm.'
- keys = list(state_dict.keys())
- for k in keys:
- val = state_dict.pop(k)
- state_dict[k[4:]] = val
-
- if self._use_shard_moe(llm):
- print_log('recover the origin state_dict from merged one ...')
- state_dict = get_origin_state_dict(state_dict, llm)
-
- print_log(f'Saving LLM to {self.out_dir}')
- llm.save_pretrained(self.out_dir, state_dict=state_dict)
-
- print_log(f'Saving LLM tokenizer to {self.out_dir}')
- tokenizer = BUILDER.build(runner.cfg.tokenizer)
- tokenizer.save_pretrained(self.out_dir)
diff --git a/code/xtuner/engine/hooks/learning_rate_hook.py b/code/xtuner/engine/hooks/learning_rate_hook.py
deleted file mode 100644
index 3db9c9d6848b547c2ebb662c340e54823fcde0dc..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/learning_rate_hook.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-from typing import Optional, Union, Dict
-
-from mmengine import print_log
-from mmengine.hooks import Hook
-from mmengine.model.wrappers import is_model_wrapper
-
-DATA_BATCH = Optional[Union[dict, tuple, list]]
-
-def _build_lr_multiplier_fn(paramwise_cfg: Optional[dict]) -> Dict[str, float]:
- """Extract custom_keys -> lr_mult mapping from paramwise_cfg."""
- custom = {}
- if paramwise_cfg and "custom_keys" in paramwise_cfg:
- for k, v in paramwise_cfg["custom_keys"].items():
- # allow either {"lr_mult": x} or direct x
- if isinstance(v, dict) and "lr_mult" in v:
- custom[k] = float(v["lr_mult"])
- elif isinstance(v, (int, float)):
- custom[k] = float(v)
- return custom
-
-def _longest_prefix_lr_mult(name: str, prefix2mult: Dict[str, float]) -> float:
- """Longest-prefix-wins. Default 1.0."""
- best_len, mult = -1, 1.0
- for prefix, m in prefix2mult.items():
- if name.startswith(prefix) and len(prefix) > best_len:
- best_len, mult = len(prefix), m
- return mult
-
-class PrintInitLRHook(Hook):
- """Print initial effective LR for all trainable parameters (ZeRO-safe)."""
-
- priority = 55
-
- def __init__(self, show_group_summary: bool = True, warn_prefix_miss: bool = True):
- self.show_group_summary = show_group_summary
- self.warn_prefix_miss = warn_prefix_miss
-
- def before_train(self, runner) -> None:
- # unwrap model
- model = runner.model.module if is_model_wrapper(runner.model) else runner.model
-
- # base LR: use optim wrapper API (works with ZeRO)
- try:
- # get_lr() returns list for groups; take the first as "base"
- base_lrs = runner.optim_wrapper.get_lr()
- base_lr = float(base_lrs[0]) if isinstance(base_lrs, (list, tuple)) else float(base_lrs)
- except Exception:
- # fallback: try optimizer param_groups
- opt = runner.optim_wrapper.optimizer
- base_lr = float(opt.param_groups[0]['lr'])
-
- # Pull the same paramwise_cfg used to build groups
- paramwise_cfg = getattr(runner.cfg, 'optim_wrapper', {}).get('paramwise_cfg', None)
- prefix2mult = _build_lr_multiplier_fn(paramwise_cfg)
-
- # Optional: show ZeRO param-group summary that’s still available
- if self.show_group_summary:
- opt = runner.optim_wrapper.optimizer
- lines = ["\n===== Initial LR per (ZeRO) param group ====="]
- for i, g in enumerate(opt.param_groups):
- lr = g.get('lr', base_lr)
- lines.append(f"Group {i}: lr={float(lr):.6e} | #params={len(g.get('params', []))}")
- lines.append("============================================")
- print_log("\n".join(lines), 'current', level=logging.INFO)
-
- # Print effective LR per parameter by recomputing the multiplier via prefixes
- header = ["\n===== Initial Learning Rates per parameter (prefix-computed) ====="]
- rows = []
- missed_prefix_examples = set()
-
- for name, p in model.named_parameters():
- if not p.requires_grad:
- continue
- mult = _longest_prefix_lr_mult(name, prefix2mult)
- eff_lr = base_lr * mult
- rows.append(f"{name:60s} lr={eff_lr:.6e} (mult={mult:g})")
-
- # collect a couple of examples that are *probably* intended to match
- # but won't if prefixes are too short
- if self.warn_prefix_miss:
- # Heuristics: common blocks that users intend to scale
- for want in ("image_score_predictor", "output_text_score_predictor",
- "instruct_score_predictor", "projector", "LongNet_encoder"):
- if want in name and not any(name.startswith(px) for px in prefix2mult):
- # only add a few examples to avoid spam
- if len(missed_prefix_examples) < 6:
- missed_prefix_examples.add(name)
-
- footer = ["==================================================================="]
- print_log("\n".join(header + rows + footer), 'current', level=logging.INFO)
-
- if self.warn_prefix_miss and missed_prefix_examples:
- # Suggest exact prefixes found in your names to avoid mismatch
- suggest = set(n.split('.')[0] for n in missed_prefix_examples) # coarse
- print_log(
- "⚠️ Some names likely didn't match your `custom_keys` prefixes.\n"
- " MMEngine uses **prefix (starts-with)** matching with longest-prefix-wins.\n"
- " From your current model, candidate full prefixes include e.g.:\n"
- " - 'llm.model.image_score_predictor'\n"
- " - 'llm.model.output_text_score_predictor'\n"
- " - 'llm.model.instruct_score_predictor'\n"
- " - 'projector'\n"
- " - 'LongNet_encoder'\n"
- " Update your `paramwise_cfg.custom_keys` to these exact prefixes.",
- 'current', level=logging.WARNING
- )
\ No newline at end of file
diff --git a/code/xtuner/engine/hooks/throughput_hook.py b/code/xtuner/engine/hooks/throughput_hook.py
deleted file mode 100644
index e74c0a0acf1e13498107364cc3cf3b4797159aaf..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/throughput_hook.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-from typing import Optional, Union
-
-import torch
-from mmengine import print_log
-from mmengine.hooks import Hook
-from mmengine.model.wrappers import is_model_wrapper
-from torch.utils._pytree import tree_flatten
-
-from xtuner.parallel.sequence import get_sequence_parallel_world_size
-
-DATA_BATCH = Optional[Union[dict, tuple, list]]
-
-
-class ThroughputHook(Hook):
-
- # priority must be higher than LoggerHook (50) and lower than
- # IterTimerHook (60)
- priority = 55
-
- def __init__(self,
- use_activation_checkpointing=None,
- hidden_size=None,
- num_layers=None,
- vocab_size=None,
- mlp_ratio=None,
- is_casual=None):
- self.use_activation_checkpointing = use_activation_checkpointing
- self.hidden_size = hidden_size
- self.num_layers = num_layers
- self.vocab_size = vocab_size
- self.mlp_ratio = mlp_ratio
- self.is_casual = is_casual
-
- @staticmethod
- def _guess_is_casual_attn(model):
- for module in model.modules():
- if hasattr(module, 'is_causal'):
- return module.is_causal
- print_log(
- 'It\'s impossible to speculate whether casual attention was used, '
- 'and FLOPs will be calculated as `casual = True`.', 'current')
- return True
-
- @staticmethod
- def _get_batch_size_and_sequence_len(data_batch):
- data_list, _ = tree_flatten(data_batch)
- for data in data_list:
- if isinstance(data, torch.Tensor):
- return data.size(0), data.size(1)
- raise RuntimeError('No tensor found in the batch')
-
- @staticmethod
- def _guess_use_activation_checkpointing(model):
- for module in model.modules():
- if hasattr(module, 'gradient_checkpointing'):
- return module.gradient_checkpointing
- return False
-
- def before_run(self, runner) -> None:
- if is_model_wrapper(runner.model):
- model = runner.model.module
- else:
- model = runner.model
- self.use_activation_checkpointing = \
- (self.use_activation_checkpointing or
- self._guess_use_activation_checkpointing(model))
- self.hidden_size = self.hidden_size or model.config.hidden_size
- self.num_layers = self.num_layers or model.config.num_hidden_layers
- self.vocab_size = self.vocab_size or model.config.vocab_size
- self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
- model.config.hidden_size)
- self.mlp_ratio *= 1.5 # has gate_proj
- self.is_casual = self.is_casual if self.is_casual is not None \
- else self._guess_is_casual_attn(model)
-
- use_varlen_attn = getattr(model, 'use_varlen_attn', False)
- if use_varlen_attn:
- print_log(
- 'Using variable-length Flash Attention causes an inflation'
- ' in the FLOPs calculation.',
- 'current',
- level=logging.WARNING)
-
- return
-
- def after_train_iter(self,
- runner,
- batch_idx: int,
- data_batch: DATA_BATCH = None,
- outputs: Optional[dict] = None) -> None:
- """Calc flops based on the paper of Megatron
- https://deepakn94.github.io/assets/papers/megatron-sc21.pdf."""
-
- batch_size, sequence_len = self._get_batch_size_and_sequence_len(
- data_batch)
- sequence_parallel_size = get_sequence_parallel_world_size()
- sequence_len /= sequence_parallel_size
-
- message_hub = runner.message_hub
- iter_time = message_hub.get_scalar('train/time').current()
-
- # We consider a language model with 𝑙 transformer layers,
- # hidden size h, sequence length s, vocabulary size V, and
- # training batch size B.
- # A $A_{mxk}$ x $X_{kxn}$ matrix multiplication requires 2𝑚 ×𝑘 ×𝑛 FLOPs
- # (factor of 2 needed to account for multiplies and adds).
-
- # Attention Layer:
- # qkv_proj + o_proj: 8B * s * h^2
- # attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True)
-
- # MLP Layer:
- # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
- # (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
- # (has gate_proj))
-
- # The backward pass requires double the number of FLOPs since we
- # need to calculate the gradients with respect to both input and
- # weight tensors. In addition, we are using activation recomputation,
- # which requires an additional forward pass before the backward pass.
-
- # While sequence parallel will affect the FLOPs calculation in attn.
- # Suppose the sequence length in one GPU is s and the sequence
- # parallel world size is `sp_size`, which means the total
- # sequence length in the attention calculation is
- # `s * sp_size` and the number of attention heads decrease to
- # `num_heads / sp_size`. Hence, the FLOPs in attn calculation is:
- # 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and
- # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True)
-
- flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
- flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
- sequence_parallel_size / (int(self.is_casual) + 1)
- flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
- self.hidden_size**2
- flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
- flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
- flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
- self.vocab_size
- flops_per_iteration = flops_wo_head + flops_head
-
- avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
- tokens_per_sec_per_gpu = batch_size * sequence_len / (
- iter_time + 1e-12)
-
- message_hub.update_scalar('train/tflops', avg_tflops_per_gpu)
- message_hub.update_scalar('train/tokens_per_sec',
- tokens_per_sec_per_gpu)
diff --git a/code/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py b/code/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
deleted file mode 100644
index fc31f21aecb44b666122db152ec6809dbaa41106..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Union
-
-from mmengine import MessageHub
-from mmengine.dist import get_rank
-from mmengine.hooks import Hook
-
-DATA_BATCH = Optional[Union[dict, tuple, list]]
-
-
-class VarlenAttnArgsToMessageHubHook(Hook):
-
- def before_train_iter(self,
- runner,
- batch_idx: int,
- data_batch: dict = None) -> None:
- rank = get_rank()
- message_hub = MessageHub.get_instance('varlen_attn_args')
-
- assert 'data' in data_batch.keys()
- data = data_batch['data']
-
- cumulative_len = data.pop('cumulative_len')
- assert len(cumulative_len) == 1
- cumulative_len = cumulative_len[0].cuda()
- message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
-
- max_seqlen = data.pop('max_seqlen')
- message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)
-
- def after_train_iter(self,
- runner,
- batch_idx: int,
- data_batch: DATA_BATCH = None,
- outputs: Optional[dict] = None) -> None:
- rank = get_rank()
- message_hub = MessageHub.get_instance('varlen_attn_args')
- message_hub.update_info(f'cumulative_len_rank_{rank}', None)
- message_hub.update_info(f'max_seqlen_rank_{rank}', None)
-
- def before_val_iter(self,
- runner,
- batch_idx: int,
- data_batch: DATA_BATCH = None) -> None:
- """All subclasses should override this method, if they need any
- operations before each validation iteration.
-
- Args:
- runner (Runner): The runner of the validation process.
- batch_idx (int): The index of the current batch in the val loop.
- data_batch (dict, optional): Data from dataloader.
- Defaults to None.
- """
- rank = get_rank()
- message_hub = MessageHub.get_instance('varlen_attn_args')
-
- assert 'data' in data_batch.keys()
- data = data_batch['data']
-
- cumulative_len = data.pop('cumulative_len')
- assert len(cumulative_len) == 1
- cumulative_len = cumulative_len[0].cuda()
- message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
-
- max_seqlen = data.pop('max_seqlen')
- message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)
-
- def after_val_iter(self,
- runner,
- batch_idx,
- data_batch=None,
- outputs=None) -> None:
- """All subclasses should override this method, if they need any
- operations after each validation iteration.
-
- Args:
- runner (Runner): The runner of the validation process.
- batch_idx (int): The index of the current batch in the val loop.
- data_batch (dict or tuple or list, optional): Data from dataloader.
- outputs (Sequence, optional): Outputs from model.
- """
- rank = get_rank()
- message_hub = MessageHub.get_instance('varlen_attn_args')
- message_hub.update_info(f'cumulative_len_rank_{rank}', None)
- message_hub.update_info(f'max_seqlen_rank_{rank}', None)
diff --git a/code/xtuner/engine/hooks/visual_warmup_book.py b/code/xtuner/engine/hooks/visual_warmup_book.py
deleted file mode 100644
index 429adf2f0dec3aa398d0cec53f11c4e0f1bd2368..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/hooks/visual_warmup_book.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# xtuner/engine/hooks/two_phase_visual_warmup_hook.py
-from __future__ import annotations
-import copy
-from typing import Dict, Any
-
-from mmengine.hooks import Hook
-from mmengine.registry import HOOKS
-
-
-def _deep_update(d: Dict[str, Any], path: str, value: Any):
- """Set d['a']['b']['c'] = value for path='a.b.c'."""
- cur = d
- keys = path.split('.')
- for k in keys[:-1]:
- if k not in cur or not isinstance(cur[k], dict):
- cur[k] = {}
- cur = cur[k]
- cur[keys[-1]] = value
-
-
-@HOOKS.register_module()
-class TwoPhaseVisualWarmupHook(Hook):
- """
- Swap to a warmup train dataloader for the first `warmup_epochs` epochs,
- then restore the original train dataloader.
-
- It rebuilds the dataloader from the config so that worker processes
- are spawned with the warmup dataset params (no post-fork mutation).
-
- Args:
- warmup_epochs (int): #epochs (starting from epoch=0) to use warmup loader.
- warmup_sample_num (int): visual tokens (patches) per image during warmup.
- align_per_image_length (bool): also set dataset.per_image_length to warmup_sample_num.
- warmup_strategy (str): 'linspace' | 'random' | 'random_full' (your new mode).
- apply_to_val (bool): if True, also rebuild/replace the val dataloader in warmup epochs.
- dataset_overrides (dict): extra dotted-key overrides applied into the warmup
- train_dataloader config, e.g. {"dataset.debug_seed": 3407}.
- """
-
- priority = 50 # run after IterTimerHook but before evaluation hooks etc.
-
- def __init__(self,
- warmup_epochs: int = 1,
- warmup_sample_num: int = 4096,
- align_per_image_length: bool = True,
- warmup_strategy: str = 'random_full',
- apply_to_val: bool = False,
- dataset_overrides: Dict[str, Any] | None = None):
- assert warmup_epochs >= 1
- assert warmup_sample_num > 0
- assert warmup_strategy in ('linspace', 'random', 'random_full')
- self.warmup_epochs = int(warmup_epochs)
- self.warmup_sample_num = int(warmup_sample_num)
- self.align_per_image_length = bool(align_per_image_length)
- self.warmup_strategy = warmup_strategy
- self.apply_to_val = bool(apply_to_val)
- self.dataset_overrides = dataset_overrides or {}
-
- # Internals
- self._orig_train_dl = None
- self._warmup_train_dl = None
- self._orig_val_dl = None
- self._warmup_val_dl = None
- self._installed = False
- self._switched_to_main = False
-
- # ---------- helpers ----------
- def _build_warmup_dl_cfg(self, base_cfg: Dict[str, Any]) -> Dict[str, Any]:
- """Clone a dataloader cfg and inject warmup dataset params."""
- cfg = copy.deepcopy(base_cfg)
-
- # Most configs look like: cfg['dataset'] = { type=..., sample_num=..., per_image_length=..., sample_strategy=... }
- _deep_update(cfg, 'dataset.sample_num', self.warmup_sample_num)
- if self.align_per_image_length:
- _deep_update(cfg, 'dataset.per_image_length', self.warmup_sample_num)
- _deep_update(cfg, 'dataset.sample_strategy', self.warmup_strategy)
-
- # Apply any caller-provided overrides.
- for k, v in self.dataset_overrides.items():
- _deep_update(cfg, k, v)
-
- return cfg
-
- def _build_dl(self, runner, cfg_dict: Dict[str, Any]):
- """Use runner's builder to create a dataloader from a config dict."""
- # MMEngine runners expose build_dataloader on Runner.
- return runner.build_dataloader(cfg_dict)
-
- def _install_warmup_train(self, runner):
- if self._installed:
- return
-
- # Keep a handle to the original dataloaders built by Runner
- # NOTE: `train_dataloader` is a read-only property; access via loop.obj
- self._orig_train_dl = runner.train_loop.dataloader
- train_cfg = copy.deepcopy(runner.cfg.train_dataloader)
-
- warm_cfg = self._build_warmup_dl_cfg(train_cfg)
- self._warmup_train_dl = self._build_dl(runner, warm_cfg)
-
- # Swap by assigning into the loop (NOT runner.train_dataloader)
- runner.train_loop.dataloader = self._warmup_train_dl
-
- if self.apply_to_val and hasattr(runner.cfg, 'val_dataloader'):
- try:
- self._orig_val_dl = getattr(runner, 'val_dataloader', None)
- val_cfg = copy.deepcopy(runner.cfg.val_dataloader)
- warm_val_cfg = self._build_warmup_dl_cfg(val_cfg)
- self._warmup_val_dl = self._build_dl(runner, warm_val_cfg)
- if hasattr(runner, 'val_loop') and runner.val_loop is not None:
- runner.val_loop.dataloader = self._warmup_val_dl
- except Exception as e:
- runner.logger.warning(f'[TwoPhaseVisualWarmupHook] Failed to build warmup val dataloader: {e}')
-
- runner.logger.info(
- f'[TwoPhaseVisualWarmupHook] Installed warmup dataloader: '
- f'sample_num={self.warmup_sample_num}, '
- f'per_image_length={"aligned" if self.align_per_image_length else "kept"}, '
- f'strategy={self.warmup_strategy}'
- )
- self._installed = True
-
- def _switch_to_main(self, runner):
- if self._switched_to_main:
- return
- if self._orig_train_dl is not None:
- runner.train_loop.dataloader = self._orig_train_dl
- if self.apply_to_val and self._orig_val_dl is not None and hasattr(runner, 'val_loop') and runner.val_loop is not None:
- runner.val_loop.dataloader = self._orig_val_dl
- runner.logger.info('[TwoPhaseVisualWarmupHook] Switched back to main dataloader.')
- self._switched_to_main = True
-
- # ---------- lifecycle ----------
- def before_train(self, runner):
- # Build and install warmup dataloader(s) BEFORE training starts.
- self._install_warmup_train(runner)
-
- def before_train_epoch(self, runner):
- # Use warmup dataloader for epochs [0, warmup_epochs-1], then switch.
- if runner.epoch >= self.warmup_epochs:
- self._switch_to_main(runner)
- else:
- # Ensure warmup is installed (defensive)
- if not self._installed:
- self._install_warmup_train(runner)
-
- def after_train(self, runner):
- # Restore main dataloader to leave runner in a clean state.
- if not self._switched_to_main:
- self._switch_to_main(runner)
\ No newline at end of file
diff --git a/code/xtuner/engine/optimizer/__init__.py b/code/xtuner/engine/optimizer/__init__.py
deleted file mode 100644
index 6b84363034b95d7fb639512c07d07a845b31ffd7..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/optimizer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .muon_wrapper import MuonOptimWrapperConstructor
-
-__all__ = ['MuonOptimWrapperConstructor']
\ No newline at end of file
diff --git a/code/xtuner/engine/optimizer/__pycache__/__init__.cpython-311.pyc b/code/xtuner/engine/optimizer/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index e4d24a901ea6b96efd0bf1eb823212a1fcfe82c1..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/optimizer/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/optimizer/__pycache__/muon_wrapper.cpython-311.pyc b/code/xtuner/engine/optimizer/__pycache__/muon_wrapper.cpython-311.pyc
deleted file mode 100644
index b427bc60138740ed64437040d43fd76862a443c7..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/optimizer/__pycache__/muon_wrapper.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/optimizer/muon_wrapper.py b/code/xtuner/engine/optimizer/muon_wrapper.py
deleted file mode 100644
index bb4a7a5906533373e90c154753188a92b85d9f1e..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/optimizer/muon_wrapper.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# BEFORE (bad): you were doing this
-# return build_optim_wrapper(model, wrapper_cfg)
-
-# AFTER (good): just return the optimizer object
-from typing import List
-import torch
-from copy import deepcopy
-from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS
-
-@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
-class MuonOptimWrapperConstructor:
- def __init__(self,
- optim_wrapper_cfg: dict,
- paramwise_cfg=None,
- muon_hidden_lr: float = 2e-2,
- muon_weight_decay: float = 1e-2,
- adamw_lr: float = 3e-4,
- adamw_betas=(0.9, 0.95),
- adamw_weight_decay: float = 1e-2,
- keep_lora_on_adamw: bool = False,
- **kwargs):
- # keep this; builder will wrap around whatever we return
- self.optim_wrapper_cfg = deepcopy(optim_wrapper_cfg)
- self.paramwise_cfg = paramwise_cfg
- self.muon_hidden_lr = muon_hidden_lr
- self.muon_weight_decay = muon_weight_decay
- self.adamw_lr = adamw_lr
- self.adamw_betas = adamw_betas
- self.adamw_weight_decay = adamw_weight_decay
- self.keep_lora_on_adamw = keep_lora_on_adamw
-
- def _trainable(self, m):
- return [p for p in m.parameters() if p.requires_grad]
-
- def __call__(self, model):
- module = model.module if hasattr(model, 'module') else model
-
- # collect trainable params from LongNet, projector, LLM (LoRA)
- params = []
- for name in ('LongNet_encoder', 'projector', 'llm'):
- if hasattr(module, name):
- params += self._trainable(getattr(module, name))
-
- def is_lora_param(p):
- # wire your own detector if needed
- return getattr(p, '_is_lora_param', False)
-
- hidden = []
- non_hidden = []
- for p in params:
- if self.keep_lora_on_adamw and is_lora_param(p):
- non_hidden.append(p)
- elif getattr(p, 'ndim', 0) >= 2:
- hidden.append(p)
- else:
- non_hidden.append(p)
-
- from muon import MuonWithAuxAdam
- param_groups = [
- dict(params=hidden, use_muon=True,
- lr=self.muon_hidden_lr, weight_decay=self.muon_weight_decay),
- dict(params=non_hidden, use_muon=False,
- lr=self.adamw_lr, betas=self.adamw_betas,
- weight_decay=self.adamw_weight_decay),
- ]
- optimizer = MuonWithAuxAdam(param_groups)
-
- # IMPORTANT: return the optimizer, not the wrapper
- return optimizer
\ No newline at end of file
diff --git a/code/xtuner/engine/runner/__init__.py b/code/xtuner/engine/runner/__init__.py
deleted file mode 100644
index d8d1c582b531e341dfbb299e56cbbd3db0b81e16..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/runner/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .loops import TrainLoop
-
-__all__ = ['TrainLoop']
diff --git a/code/xtuner/engine/runner/__pycache__/__init__.cpython-311.pyc b/code/xtuner/engine/runner/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index fd976b441d6f9e2856ca9c6fab168c9f102f8f23..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/runner/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/runner/__pycache__/loops.cpython-311.pyc b/code/xtuner/engine/runner/__pycache__/loops.cpython-311.pyc
deleted file mode 100644
index 9b30ec37481731764b26894cda16d928c9baedb2..0000000000000000000000000000000000000000
Binary files a/code/xtuner/engine/runner/__pycache__/loops.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/engine/runner/loops.py b/code/xtuner/engine/runner/loops.py
deleted file mode 100644
index aeb6be31ae6e09c32fb27f60c82690d4fc94b84a..0000000000000000000000000000000000000000
--- a/code/xtuner/engine/runner/loops.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Optional, Union
-
-from mmengine.runner import IterBasedTrainLoop
-from torch.utils.data import DataLoader
-
-
-class TrainLoop(IterBasedTrainLoop):
-
- def __init__(self,
- runner,
- dataloader: Union[DataLoader, Dict],
- max_iters: Optional[int] = None,
- max_epochs: Union[int, float] = None,
- **kwargs) -> None:
-
- if max_iters is None and max_epochs is None:
- raise RuntimeError('Please specify the `max_iters` or '
- '`max_epochs` in `train_cfg`.')
- elif max_iters is not None and max_epochs is not None:
- raise RuntimeError('Only one of `max_iters` or `max_epochs` can '
- 'exist in `train_cfg`.')
- else:
- if max_iters is not None:
- iters = int(max_iters)
- assert iters == max_iters, ('`max_iters` should be a integer '
- f'number, but get {max_iters}')
- elif max_epochs is not None:
- if isinstance(dataloader, dict):
- diff_rank_seed = runner._randomness_cfg.get(
- 'diff_rank_seed', False)
- dataloader = runner.build_dataloader(
- dataloader,
- seed=runner.seed,
- diff_rank_seed=diff_rank_seed)
- iters = max_epochs * len(dataloader)
- else:
- raise NotImplementedError
- super().__init__(
- runner=runner, dataloader=dataloader, max_iters=iters, **kwargs)
diff --git a/code/xtuner/entry_point.py b/code/xtuner/entry_point.py
deleted file mode 100644
index 68d1d1d8f1c13c28ba6114cb4de9a01865daf935..0000000000000000000000000000000000000000
--- a/code/xtuner/entry_point.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-import os
-import random
-import subprocess
-import sys
-
-from mmengine.logging import print_log
-
-import xtuner
-
-# Define valid modes
-MODES = ('list-cfg', 'copy-cfg', 'log-dataset', 'check-custom-dataset',
- 'train', 'test', 'chat', 'convert', 'preprocess', 'mmbench',
- 'eval_refcoco', "test_token_compressor", "test_fusion_compressor","test_llm_only",
- "test_random", "test_dynamic_llava", "test_fastv_llava", "test_wsi_llava")
-
-CLI_HELP_MSG = \
- f"""
- Arguments received: {str(['xtuner'] + sys.argv[1:])}. xtuner commands use the following syntax:
-
- xtuner MODE MODE_ARGS ARGS
-
- Where MODE (required) is one of {MODES}
- MODE_ARG (optional) is the argument for specific mode
- ARGS (optional) are the arguments for specific command
-
- Some usages for xtuner commands: (See more by using -h for specific command!)
-
- 1. List all predefined configs:
- xtuner list-cfg
- 2. Copy a predefined config to a given path:
- xtuner copy-cfg $CONFIG $SAVE_FILE
- 3-1. Fine-tune LLMs by a single GPU:
- xtuner train $CONFIG
- 3-2. Fine-tune LLMs by multiple GPUs:
- NPROC_PER_NODE=$NGPUS NNODES=$NNODES NODE_RANK=$NODE_RANK PORT=$PORT ADDR=$ADDR xtuner dist_train $CONFIG $GPUS
- 4-1. Convert the pth model to HuggingFace's model:
- xtuner convert pth_to_hf $CONFIG $PATH_TO_PTH_MODEL $SAVE_PATH_TO_HF_MODEL
- 4-2. Merge the HuggingFace's adapter to the pretrained base model:
- xtuner convert merge $LLM $ADAPTER $SAVE_PATH
- xtuner convert merge $CLIP $ADAPTER $SAVE_PATH --is-clip
- 4-3. Split HuggingFace's LLM to the smallest sharded one:
- xtuner convert split $LLM $SAVE_PATH
- 5-1. Chat with LLMs with HuggingFace's model and adapter:
- xtuner chat $LLM --adapter $ADAPTER --prompt-template $PROMPT_TEMPLATE --system-template $SYSTEM_TEMPLATE
- 5-2. Chat with VLMs with HuggingFace's model and LLaVA:
- xtuner chat $LLM --llava $LLAVA --visual-encoder $VISUAL_ENCODER --image $IMAGE --prompt-template $PROMPT_TEMPLATE --system-template $SYSTEM_TEMPLATE
- 6-1. Preprocess arxiv dataset:
- xtuner preprocess arxiv $SRC_FILE $DST_FILE --start-date $START_DATE --categories $CATEGORIES
- 6-2. Preprocess refcoco dataset:
- xtuner preprocess refcoco --ann-path $RefCOCO_ANN_PATH --image-path $COCO_IMAGE_PATH --save-path $SAVE_PATH
- 7-1. Log processed dataset:
- xtuner log-dataset $CONFIG
- 7-2. Verify the correctness of the config file for the custom dataset:
- xtuner check-custom-dataset $CONFIG
- 8. MMBench evaluation:
- xtuner mmbench $LLM --llava $LLAVA --visual-encoder $VISUAL_ENCODER --prompt-template $PROMPT_TEMPLATE --data-path $MMBENCH_DATA_PATH
- 9. Refcoco evaluation:
- xtuner eval_refcoco $LLM --llava $LLAVA --visual-encoder $VISUAL_ENCODER --prompt-template $PROMPT_TEMPLATE --data-path $REFCOCO_DATA_PATH
- 10. List all dataset formats which are supported in XTuner
-
- Run special commands:
-
- xtuner help
- xtuner version
-
- GitHub: https://github.com/InternLM/xtuner
- """ # noqa: E501
-
-
-CONVERT_HELP_MSG = \
- f"""
- Arguments received: {str(['xtuner'] + sys.argv[1:])}. xtuner commands use the following syntax:
-
- xtuner MODE MODE_ARGS ARGS
-
- Where MODE (required) is one of {MODES}
- MODE_ARG (optional) is the argument for specific mode
- ARGS (optional) are the arguments for specific command
-
- Some usages for convert: (See more by using -h for specific command!)
-
- 1. Convert the pth model to HuggingFace's model:
- xtuner convert pth_to_hf $CONFIG $PATH_TO_PTH_MODEL $SAVE_PATH_TO_HF_MODEL
- 2. Merge the HuggingFace's adapter to the pretrained LLM:
- xtuner convert merge $LLM $ADAPTER $SAVE_PATH
- 3. Split HuggingFace's LLM to the smallest sharded one:
- xtuner convert split $LLM $SAVE_PATH
-
- GitHub: https://github.com/InternLM/xtuner
- """ # noqa: E501
-
-
-PREPROCESS_HELP_MSG = \
- f"""
- Arguments received: {str(['xtuner'] + sys.argv[1:])}. xtuner commands use the following syntax:
-
- xtuner MODE MODE_ARGS ARGS
-
- Where MODE (required) is one of {MODES}
- MODE_ARG (optional) is the argument for specific mode
- ARGS (optional) are the arguments for specific command
-
- Some usages for preprocess: (See more by using -h for specific command!)
-
- 1. Preprocess arxiv dataset:
- xtuner preprocess arxiv $SRC_FILE $DST_FILE --start-date $START_DATE --categories $CATEGORIES
- 2. Preprocess refcoco dataset:
- xtuner preprocess refcoco --ann-path $RefCOCO_ANN_PATH --image-path $COCO_IMAGE_PATH --save-path $SAVE_PATH
-
- GitHub: https://github.com/InternLM/xtuner
- """ # noqa: E501
-
-special = {
- 'help': lambda: print_log(CLI_HELP_MSG, 'current'),
- 'version': lambda: print_log(xtuner.__version__, 'current')
-}
-special = {
- **special,
- **{f'-{k[0]}': v
- for k, v in special.items()},
- **{f'--{k}': v
- for k, v in special.items()}
-}
-
-
-def list_dataset_format():
- from xtuner.tools import list_dataset_format
- return list_dataset_format.__file__
-
-
-def list_cfg():
- from xtuner.tools import list_cfg
- return list_cfg.__file__
-
-
-def copy_cfg():
- from xtuner.tools import copy_cfg
- return copy_cfg.__file__
-
-
-def log_dataset():
- from xtuner.tools import log_dataset
- return log_dataset.__file__
-
-
-def check_custom_dataset():
- from xtuner.tools import check_custom_dataset
- return check_custom_dataset.__file__
-
-
-def train():
- from xtuner.tools import train
- return train.__file__
-
-
-def test():
- from xtuner.tools import test
- return test.__file__
-
-def test_token_compressor():
- from xtuner.tools import test_token_compressor
- return test_token_compressor.__file__
-
-def test_llm_only():
- from xtuner.tools import test_llm_only
- return test_llm_only.__file__
-
-def test_random():
- from xtuner.tools import test_random
- return test_random.__file__
-
-def test_wsi_llava():
- from xtuner.tools import test_for_json_files
- return test_for_json_files.__file__
-
-def test_dynamic_llava():
- from xtuner.tools import test_dynamic_llava
- return test_dynamic_llava.__file__
-
-def test_fusion_compressor():
- from xtuner.tools import test_fusion_compressor
- return test_fusion_compressor.__file__
-
-def test_fastv_llava():
- from xtuner.tools import test_fastv_llava
- return test_fastv_llava.__file__
-
-
-def chat():
- from xtuner.tools import chat
- return chat.__file__
-
-
-def mmbench():
- from xtuner.tools import mmbench
- return mmbench.__file__
-
-
-def pth_to_hf():
- from xtuner.tools.model_converters import pth_to_hf
- return pth_to_hf.__file__
-
-
-def merge():
- from xtuner.tools.model_converters import merge
- return merge.__file__
-
-
-def split():
- from xtuner.tools.model_converters import split
- return split.__file__
-
-
-def arxiv_preprocess():
- from xtuner.tools.data_preprocess import arxiv as arxiv_preprocess
- return arxiv_preprocess.__file__
-
-
-def convert_refcoco():
- from xtuner.tools.data_preprocess import convert_refcoco
- return convert_refcoco.__file__
-
-
-def convert_help_msg():
- print_log(CONVERT_HELP_MSG, 'current')
-
-
-def preprocess_help_msg():
- print_log(PREPROCESS_HELP_MSG, 'current')
-
-
-def eval_refcoco():
- from xtuner.tools import eval_refcoco
- return eval_refcoco.__file__
-
-
-modes = {
- 'list-cfg': list_cfg,
- 'copy-cfg': copy_cfg,
- 'log-dataset': log_dataset,
- 'check-custom-dataset': check_custom_dataset,
- 'train': train,
- 'test': test,
- "test_token_compressor": test_token_compressor,
- "test_fusion_compressor": test_fusion_compressor,
- "test_llm_only": test_llm_only,
- "test_random": test_random,
- "test_wsi_llava": test_wsi_llava,
- "test_dynamic_llava": test_dynamic_llava,
- "test_fastv_llava": test_fastv_llava,
- 'chat': chat,
- 'mmbench': mmbench,
- 'convert': {
- 'pth_to_hf': pth_to_hf,
- 'merge': merge,
- 'split': split,
- '--help': convert_help_msg,
- '-h': convert_help_msg
- },
- 'preprocess': {
- 'arxiv': arxiv_preprocess,
- 'refcoco': convert_refcoco,
- '--help': preprocess_help_msg,
- '-h': preprocess_help_msg
- },
- 'eval_refcoco': eval_refcoco,
- 'list-dataset-format': list_dataset_format
-}
-
-HELP_FUNCS = [preprocess_help_msg, convert_help_msg]
-MAP_FILE_FUNCS = [
- list_cfg, copy_cfg, log_dataset, check_custom_dataset, train, test, chat,
- mmbench, pth_to_hf, merge, split, arxiv_preprocess, eval_refcoco,
- convert_refcoco, list_dataset_format, test_token_compressor, test_fusion_compressor,
- test_llm_only, test_random, test_wsi_llava,
- test_dynamic_llava, test_fastv_llava
-]
-
-
-def cli():
- args = sys.argv[1:]
- if not args: # no arguments passed
- print_log(CLI_HELP_MSG, 'current')
- return
- if args[0].lower() in special:
- special[args[0].lower()]()
- return
- elif args[0].lower() in modes:
- try:
- fn_or_dict = modes[args[0].lower()]
- n_arg = 0
-
- if isinstance(fn_or_dict, dict):
- n_arg += 1
- fn = fn_or_dict[args[n_arg].lower()]
- else:
- fn = fn_or_dict
-
- assert callable(fn)
-
- if fn in HELP_FUNCS:
- fn()
- else:
- slurm_launcher = False
- for i in range(n_arg + 1, len(args)):
- if args[i] == '--launcher':
- if i + 1 < len(args) and args[i + 1] == 'slurm':
- slurm_launcher = True
- break
- nnodes = int(os.environ.get('NNODES', 1))
- nproc_per_node = int(os.environ.get('NPROC_PER_NODE', 1))
- if slurm_launcher or (nnodes == 1 and nproc_per_node == 1):
- subprocess.run(['python', fn()] + args[n_arg + 1:])
- else:
- port = os.environ.get('PORT', None)
- if port is None:
- port = random.randint(20000, 29999)
- print_log(f'Use random port: {port}', 'current',
- logging.WARNING)
- torchrun_args = [
- f'--nnodes={nnodes}',
- f"--node_rank={os.environ.get('NODE_RANK', 0)}",
- f'--nproc_per_node={nproc_per_node}',
- f"--master_addr={os.environ.get('ADDR', '127.0.0.1')}",
- f'--master_port={port}'
- ]
- subprocess.run(['torchrun'] + torchrun_args + [fn()] +
- args[n_arg + 1:] +
- ['--launcher', 'pytorch'])
- except Exception as e:
- print_log(f"WARNING: command error: '{e}'!", 'current',
- logging.WARNING)
- print_log(CLI_HELP_MSG, 'current', logging.WARNING)
- return
- else:
- print_log('WARNING: command error!', 'current', logging.WARNING)
- print_log(CLI_HELP_MSG, 'current', logging.WARNING)
- return
diff --git a/code/xtuner/evaluation/.DS_Store b/code/xtuner/evaluation/.DS_Store
deleted file mode 100644
index 4fc9c8a0b88810f82461b8395d483c5eb5d2e832..0000000000000000000000000000000000000000
Binary files a/code/xtuner/evaluation/.DS_Store and /dev/null differ
diff --git a/code/xtuner/evaluation/__init__.py b/code/xtuner/evaluation/__init__.py
deleted file mode 100644
index fba3e590598c3fe175f9d331e0da8883c1ef4ea8..0000000000000000000000000000000000000000
--- a/code/xtuner/evaluation/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .metrics import MMLUMetric
-
-__all__ = ['MMLUMetric']
diff --git a/code/xtuner/evaluation/metrics/__init__.py b/code/xtuner/evaluation/metrics/__init__.py
deleted file mode 100644
index f3efc80fd5d8aa3f7b65e43ec1a8acd98a1df3bb..0000000000000000000000000000000000000000
--- a/code/xtuner/evaluation/metrics/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .mmlu_metric import MMLUMetric
-
-__all__ = ['MMLUMetric']
diff --git a/code/xtuner/evaluation/metrics/mmlu_metric.py b/code/xtuner/evaluation/metrics/mmlu_metric.py
deleted file mode 100644
index ad1282056a8e7691f05f579275ad0bf990796f12..0000000000000000000000000000000000000000
--- a/code/xtuner/evaluation/metrics/mmlu_metric.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Sequence
-
-import numpy as np
-import torch
-from mmengine.evaluator import BaseMetric
-from mmengine.logging import print_log
-from rich.console import Console
-from rich.table import Table
-
-from xtuner.registry import BUILDER
-
-
-class MMLUMetric(BaseMetric):
- METAINFO = {
- 'subcategories': {
- 'abstract_algebra': ['math'],
- 'anatomy': ['health'],
- 'astronomy': ['physics'],
- 'business_ethics': ['business'],
- 'clinical_knowledge': ['health'],
- 'college_biology': ['biology'],
- 'college_chemistry': ['chemistry'],
- 'college_computer_science': ['computer science'],
- 'college_mathematics': ['math'],
- 'college_medicine': ['health'],
- 'college_physics': ['physics'],
- 'computer_security': ['computer science'],
- 'conceptual_physics': ['physics'],
- 'econometrics': ['economics'],
- 'electrical_engineering': ['engineering'],
- 'elementary_mathematics': ['math'],
- 'formal_logic': ['philosophy'],
- 'global_facts': ['other'],
- 'high_school_biology': ['biology'],
- 'high_school_chemistry': ['chemistry'],
- 'high_school_computer_science': ['computer science'],
- 'high_school_european_history': ['history'],
- 'high_school_geography': ['geography'],
- 'high_school_government_and_politics': ['politics'],
- 'high_school_macroeconomics': ['economics'],
- 'high_school_mathematics': ['math'],
- 'high_school_microeconomics': ['economics'],
- 'high_school_physics': ['physics'],
- 'high_school_psychology': ['psychology'],
- 'high_school_statistics': ['math'],
- 'high_school_us_history': ['history'],
- 'high_school_world_history': ['history'],
- 'human_aging': ['health'],
- 'human_sexuality': ['culture'],
- 'international_law': ['law'],
- 'jurisprudence': ['law'],
- 'logical_fallacies': ['philosophy'],
- 'machine_learning': ['computer science'],
- 'management': ['business'],
- 'marketing': ['business'],
- 'medical_genetics': ['health'],
- 'miscellaneous': ['other'],
- 'moral_disputes': ['philosophy'],
- 'moral_scenarios': ['philosophy'],
- 'nutrition': ['health'],
- 'philosophy': ['philosophy'],
- 'prehistory': ['history'],
- 'professional_accounting': ['other'],
- 'professional_law': ['law'],
- 'professional_medicine': ['health'],
- 'professional_psychology': ['psychology'],
- 'public_relations': ['politics'],
- 'security_studies': ['politics'],
- 'sociology': ['culture'],
- 'us_foreign_policy': ['politics'],
- 'virology': ['health'],
- 'world_religions': ['philosophy'],
- },
- 'categories': {
- 'STEM': [
- 'physics', 'chemistry', 'biology', 'computer science', 'math',
- 'engineering'
- ],
- 'humanities': ['history', 'philosophy', 'law'],
- 'social sciences':
- ['politics', 'culture', 'economics', 'geography', 'psychology'],
- 'other (business, health, misc.)': ['other', 'business', 'health'],
- },
- }
- METAINFO['subcategories_list'] = list({
- subcat
- for subcats in METAINFO['subcategories'].values() for subcat in subcats
- })
-
- def __init__(self, tokenizer, *args, **kwargs):
- super().__init__(*args, **kwargs)
- tokenizer = BUILDER.build(tokenizer)
- self.abcd_idx = [
- tokenizer.encode('A', add_special_tokens=False)[0],
- tokenizer.encode('B', add_special_tokens=False)[0],
- tokenizer.encode('C', add_special_tokens=False)[0],
- tokenizer.encode('D', add_special_tokens=False)[0],
- ]
-
- @staticmethod
- def ABCD_to_0123(abcd):
- return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]
-
- @staticmethod
- def find_first_zero_index(tensor):
- indices = torch.nonzero(tensor == 0)
- if indices.numel() > 0:
- return indices[0].item()
- else:
- return None
-
- @staticmethod
- def accuracy(preds, gts):
- """Computes the accuracy for preds and gts."""
- correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
- acc = np.mean(correct) * 100
- return acc
-
- def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
- """Process one batch of data samples and predictions. The processed
- results should be stored in ``self.results``, which will be used to
- compute the metrics when all batches have been processed.
-
- Args:
- data_batch (Any): A batch of data from the dataloader.
- data_samples (Sequence[dict]): A batch of outputs from
- the model.
- """
- subjects = data_batch['data_samples']['subjects']
- gts = [
- self.ABCD_to_0123(gt)
- for gt in data_batch['data_samples']['labels']
- ]
- preds = []
- for sample, attn_mask, subject, gt in zip(
- data_samples, data_batch['data']['attention_mask'], subjects,
- gts):
- pred_logits = sample['logits']
- first_zero_idx = self.find_first_zero_index(attn_mask)
- pred_idx = -1 if first_zero_idx is None else first_zero_idx - 1
- pred_logtis_abcd = pred_logits[pred_idx, self.abcd_idx]
- pred = torch.argmax(pred_logtis_abcd).item()
- preds.append(pred)
- self.results.append((subject, pred, gt))
-
- def compute_metrics(self, results: list) -> dict:
- """Compute the metrics from processed results.
-
- Args:
- results (list): The processed results of each batch.
-
- Returns:
- dict: The computed metrics. The keys are the names of the metrics,
- and the values are corresponding results.
- """
- subjects_results = {
- subject: {
- 'preds': [],
- 'gts': []
- }
- for subject in self.METAINFO['subcategories'].keys()
- }
- subcats_results = {
- subcat: {
- 'preds': [],
- 'gts': []
- }
- for subcat in self.METAINFO['subcategories_list']
- }
- cats_results = {
- cat: {
- 'preds': [],
- 'gts': []
- }
- for cat in self.METAINFO['categories'].keys()
- }
- for subject, pred, gt in results:
- subjects_results[subject]['preds'].append(pred)
- subjects_results[subject]['gts'].append(gt)
- subcats = self.METAINFO['subcategories'][subject]
- for subcat in subcats:
- subcats_results[subcat]['preds'].append(pred)
- subcats_results[subcat]['gts'].append(gt)
- for cat, subcats in self.METAINFO['categories'].items():
- for subcat in subcats:
- if subcat in subcats_results:
- cats_results[cat]['preds'].extend(
- subcats_results[subcat]['preds'])
- cats_results[cat]['gts'].extend(
- subcats_results[subcat]['gts'])
-
- subjects_metrics = dict()
- subcats_metrics = dict()
- cats_metrics = dict()
- for subject in self.METAINFO['subcategories'].keys():
- assert len(subjects_results[subject]['preds']) == len(
- subjects_results[subject]['gts'])
- if len(subjects_results[subject]['preds']) == 0:
- print_log(f'Skip subject {subject} for mmlu', 'current')
- else:
- score = self.accuracy(subjects_results[subject]['preds'],
- subjects_results[subject]['gts'])
- subjects_metrics[f'{subject}'] = score
- for subcat in self.METAINFO['subcategories_list']:
- assert len(subcats_results[subcat]['preds']) == len(
- subcats_results[subcat]['gts'])
- if len(subcats_results[subcat]['preds']) == 0:
- print_log(f'Skip subcategory {subcat} for mmlu', 'current')
- else:
- score = self.accuracy(subcats_results[subcat]['preds'],
- subcats_results[subcat]['gts'])
- subcats_metrics[f'{subcat}'] = score
- for cat in self.METAINFO['categories'].keys():
- assert len(cats_results[cat]['preds']) == len(
- cats_results[cat]['gts'])
- if len(cats_results[cat]['preds']) == 0:
- print_log(f'Skip category {cat} for mmlu', 'current')
- else:
- score = self.accuracy(cats_results[cat]['preds'],
- cats_results[cat]['gts'])
- cats_metrics[f'{cat}'] = score
-
- metrics = dict()
- metrics.update(subjects_metrics)
- metrics.update(subcats_metrics)
- metrics.update(cats_metrics)
- metrics['average'] = np.mean(list(subjects_metrics.values()))
-
- table_metrics = dict()
- table_metrics.update(cats_metrics)
- table_metrics['average'] = np.mean(list(subjects_metrics.values()))
- self._print_results(table_metrics)
- return metrics
-
- def _print_results(self, table_metrics: dict) -> None:
- table_title = ' MMLU Benchmark '
- table = Table(title=table_title)
- console = Console()
- table.add_column('Categories', justify='left')
- table.add_column('Accuracy (%)', justify='right')
- for cat, acc in table_metrics.items():
- table.add_row(cat, f'{acc:.1f}')
- with console.capture() as capture:
- console.print(table, end='')
- print_log('\n' + capture.get(), 'current')
diff --git a/code/xtuner/evaluation/metrics/reward_metric.py b/code/xtuner/evaluation/metrics/reward_metric.py
deleted file mode 100644
index c5d019978c9ebbfe2debd42b113f64aba9274423..0000000000000000000000000000000000000000
--- a/code/xtuner/evaluation/metrics/reward_metric.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import itertools
-from collections import defaultdict
-from typing import List, Optional, Sequence
-
-import torch
-from mmengine.evaluator import BaseMetric
-from mmengine.logging import print_log
-from rich.console import Console
-from rich.table import Table
-
-
-class RewardMetric(BaseMetric):
- r"""Reward model evaluation metric.
- """
- default_prefix: Optional[str] = ''
-
- def __init__(self,
- collect_device: str = 'cpu',
- prefix: Optional[str] = None) -> None:
- super().__init__(collect_device=collect_device, prefix=prefix)
-
- def process(self, data_batch, data_samples: Sequence[dict]):
- """Process one batch of data samples.
-
- The processed results should be stored in ``self.results``, which will
- be used to computed the metrics when all batches have been processed.
-
- Args:
- data_batch: A batch of data from the dataloader.
- data_samples (Sequence[dict]): A batch of outputs from the model.
- """
- logits = torch.cat(
- [sample['logits'].unsqueeze(0) for sample in data_samples], dim=0)
- labels = data_batch['data']['labels']
- ds_names = data_batch['data_samples']['ds_names']
- chosen_idx = torch.where(labels == 0)
- rejected_idx = torch.where(labels == 1)
- chosen_logits = logits[chosen_idx].cpu()
- rejected_logits = logits[rejected_idx].cpu()
-
- correct = (chosen_logits > rejected_logits).cpu()
- self.results.append({
- 'chosen_logits': chosen_logits,
- 'rejected_logits': rejected_logits,
- 'correct': correct,
- 'ds_names': ds_names
- })
-
- def compute_metrics(self, results: List):
- """Compute the metrics from processed results.
-
- Args:
- results (dict): The processed results of each batch.
-
- Returns:
- Dict: The computed metrics. The keys are the names of the metrics,
- and the values are corresponding results.
- """
- # NOTICE: don't access `self.results` from the method.
- metrics = {}
-
- correct = torch.cat([res['correct'] for res in results])
- chosen_logits = torch.cat([res['chosen_logits'] for res in results])
- rejected_logits = torch.cat(
- [res['rejected_logits'] for res in results])
- ds_names = list(itertools.chain(*[res['ds_names'] for res in results]))
-
- # group by ds_names
- grouped_correct = defaultdict(list)
- grouped_chosen_logits = defaultdict(list)
- grouped_rejected_logits = defaultdict(list)
- for i, ds_name in enumerate(ds_names):
- grouped_correct[ds_name].append(correct[i])
- grouped_chosen_logits[ds_name].append(chosen_logits[i])
- grouped_rejected_logits[ds_name].append(rejected_logits[i])
-
- # print metrics in a rich table
- table = Table(title='Reward Metrics')
- table.add_column('Dataset Name')
- table.add_column('Accuracy')
- table.add_column('Chosen Score')
- table.add_column('Rejected Score')
-
- for ds_name in grouped_correct.keys():
- correct = torch.stack(grouped_correct[ds_name])
- chosen_logits = torch.stack(grouped_chosen_logits[ds_name])
- rejected_logits = torch.stack(grouped_rejected_logits[ds_name])
-
- acc = correct.float().mean()
- metrics[f'accuracy/{ds_name}'] = acc.item()
- metrics[f'chosen_score/{ds_name}'] = chosen_logits.mean().item()
- metrics[f'rejected_score{ds_name}'] = rejected_logits.mean().item()
-
- table.add_row(ds_name, f'{acc:.4f}', f'{chosen_logits.mean():.4f}',
- f'{rejected_logits.mean():.4f}')
-
- console = Console()
- with console.capture() as capture:
- console.print(table, end='')
- print_log('\n' + capture.get(), 'current')
-
- return metrics
diff --git a/code/xtuner/inference_wl.sh b/code/xtuner/inference_wl.sh
deleted file mode 100644
index ee353198819d6daeb124df9791f035ab12e99a68..0000000000000000000000000000000000000000
--- a/code/xtuner/inference_wl.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-
-
-cd /data/qingq/PathVLM/baselines/github/SlideChat/xtuner
-
-CUDA_VISIBLE_DEVICES=2 xtuner test configs/slidechat/stage_2.py --checkpoint /data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth --test_slide_csv /data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv --test_output_csv /data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_my_test_skcm.csv --local_rank 0 --tumor_type SKCM
-
-CUDA_VISIBLE_DEVICES=2 xtuner test configs/slidechat/stage_2.py --checkpoint /data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth --test_slide_csv /data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv --test_output_csv /data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_my_test_read.csv --local_rank 0 --tumor_type READ
-
-CUDA_VISIBLE_DEVICES=2 xtuner test configs/slidechat/stage_2.py --checkpoint /data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth --test_slide_csv /data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv --test_output_csv /data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_my_test_hnsc.csv --local_rank 0 --tumor_type HNSC
\ No newline at end of file
diff --git a/code/xtuner/model/__init__.py b/code/xtuner/model/__init__.py
deleted file mode 100644
index 896615a09b7e71e79f27fd290a1b4df21e9ff83c..0000000000000000000000000000000000000000
--- a/code/xtuner/model/__init__.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import traceback
-
-errors = []
-
-try:
- from .internvl import InternVL_V1_5
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava import LLaVAModel
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_attn import LLaVAModel_Attn
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .sft import SupervisedFinetune
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_dim_reducer import LLaVAModelWithReducer
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_divprune import LLaVAModel as LLaVAModel_DivPrune
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_compressor import LLaVAModel as LLaVAModel_Compressor
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_fusion_compressor import LLaVAModel as LLaVAModel_FusionCompressor
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_acmil import LLaVAModel as LLaVAModel_ACMIL
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .dynamic_llava import DynamicLLaVAQwen25, DynamicQwen2ForCausalLM
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .llava_fewer import LLaVAModel as LLaVAModel_Fewer
-except Exception:
- errors.append(traceback.format_exc())
-
-try:
- from .fastv import Qwen25ModelFastV, Qwen25ForCausalLMFastV
-except Exception:
- errors.append(traceback.format_exc())
-
-if errors:
- print("One or more import errors occurred:\n")
- for err in errors:
- print(err)
-else:
- __all__ = [
- 'SupervisedFinetune', 'LLaVAModel', 'InternVL_V1_5', 'LLaVAModel_Attn', 'LLaVAModelWithReducer',
- 'LLaVAModel_DivPrune', 'LLaVAModel_Compressor', 'LLaVAModel_FusionCompressor', 'LLaVAModel_ACMIL',
- 'DynamicLLaVAQwen25', 'DynamicQwen2ForCausalLM', 'LLaVAModel_Fewer', 'Qwen25ModelFastV',
- 'Qwen25ForCausalLMFastV'
- ]
\ No newline at end of file
diff --git a/code/xtuner/model/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 32d9f7c090be03aa2fefbf5fbc5960fbbb050192..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/internvl.cpython-311.pyc b/code/xtuner/model/__pycache__/internvl.cpython-311.pyc
deleted file mode 100644
index 42bf9d15cdd9a1f267371431031916d8e47546e4..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/internvl.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava.cpython-311.pyc b/code/xtuner/model/__pycache__/llava.cpython-311.pyc
deleted file mode 100644
index 128c1b3d9bd7349000ca6e0d0f6e07a041e94b87..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_acmil.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_acmil.cpython-311.pyc
deleted file mode 100644
index 7322bfbdfaa83c7aab9c670a2eb9e89cf5a4b7c1..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_acmil.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_attn.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_attn.cpython-311.pyc
deleted file mode 100644
index 6006e3ee888d7f1f4524121671e85f824a8296e9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_attn.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_compressor.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_compressor.cpython-311.pyc
deleted file mode 100644
index 0bb8fa56f5ca581b3f85529fbed219b0315fcc1e..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_compressor.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_dim_reducer.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_dim_reducer.cpython-311.pyc
deleted file mode 100644
index a0c264eedf60681e4c2fd07068c1fef1f8f44c23..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_dim_reducer.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_divprune.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_divprune.cpython-311.pyc
deleted file mode 100644
index fd99dba6f312c11ba84f817dd89751bae3e51f05..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_divprune.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_fewer.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_fewer.cpython-311.pyc
deleted file mode 100644
index b539008767a1d8f4e03dabc91d6142ad7b0b1ec1..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_fewer.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_fusion_compressor.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_fusion_compressor.cpython-311.pyc
deleted file mode 100644
index 21ba5376f82200e5215e5bee2da02cb993d5c4fc..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_fusion_compressor.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_no_longnet.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_no_longnet.cpython-311.pyc
deleted file mode 100644
index b3c168b89f7c1ff412d0b93f460b1f0b9b93371c..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_no_longnet.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_no_longnet_simple_sampler.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_no_longnet_simple_sampler.cpython-311.pyc
deleted file mode 100644
index 44fca5426a8d4806093b6efba49265142c3c39a4..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_no_longnet_simple_sampler.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/llava_only_projector.cpython-311.pyc b/code/xtuner/model/__pycache__/llava_only_projector.cpython-311.pyc
deleted file mode 100644
index 70dabf150da6e8b971598aedabcb71c541bb7ab9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/llava_only_projector.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/sft.cpython-311.pyc b/code/xtuner/model/__pycache__/sft.cpython-311.pyc
deleted file mode 100644
index b97fb79f4b68dbf1f0c5ad9ab29818f50019be19..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/sft.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/sparse_token_merge.cpython-311.pyc b/code/xtuner/model/__pycache__/sparse_token_merge.cpython-311.pyc
deleted file mode 100644
index 9cf8d90b38c22a129a6a669b0ac31d97e7597426..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/sparse_token_merge.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/__pycache__/utils.cpython-311.pyc b/code/xtuner/model/__pycache__/utils.cpython-311.pyc
deleted file mode 100644
index 26a689a4541a0ea3db887102e69cdb9a2eb4c4df..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/__pycache__/utils.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/architecture/Attention.py b/code/xtuner/model/architecture/Attention.py
deleted file mode 100644
index 15da7c8a1f9879fbe4d6dbe21fef9252f01abc40..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/Attention.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from architecture.network import Classifier_1fc
-
-class Attention2(nn.Module):
- def __init__(self, L=512, D=128, K=1):
- super(Attention2, self).__init__()
-
- self.L = L
- self.D = D
- self.K = K
-
- self.attention = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Tanh(),
- nn.Linear(self.D, self.K)
- )
-
- def forward(self, x, isNorm=True):
- ## x: N x L
- A = self.attention(x) ## N x K
- A = torch.transpose(A, 1, 0) # KxN
- if isNorm:
- A = F.softmax(A, dim=1) # softmax over N
- return A ### K x N
-
-
-class Attention_Gated(nn.Module):
- def __init__(self, L=512, D=128, K=1):
- super(Attention_Gated, self).__init__()
-
- self.L = L
- self.D = D
- self.K = K
-
- self.attention_V = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Tanh()
- )
-
- self.attention_U = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Sigmoid()
- )
-
- self.attention_weights = nn.Linear(self.D, self.K)
-
- def forward(self, x, isNorm=True):
- ## x: N x L
- A_V = self.attention_V(x) # NxD
- A_U = self.attention_U(x) # NxD
- A = self.attention_weights(A_V * A_U) # NxK
- A = torch.transpose(A, 1, 0) # KxN
-
- if isNorm:
- A = F.softmax(A, dim=1) # softmax over N
-
- return A ### K x N
-
-
-class Attention_with_Classifier(nn.Module):
- def __init__(self, L=512, D=128, K=1, num_cls=2, droprate=0):
- super(Attention_with_Classifier, self).__init__()
- self.attention = Attention_Gated(L, D, K)
- self.classifier = Classifier_1fc(L, num_cls, droprate)
- def forward(self, x): ## x: N x L
- AA = self.attention(x) ## K x N
- afeat = torch.mm(AA, x) ## K x L
- pred = self.classifier(afeat) ## K x num_cls
- return pred
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/__init__.py b/code/xtuner/model/architecture/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/code/xtuner/model/architecture/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/architecture/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index f681bb085433cdbc887f40372247cb21e81e4059..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/architecture/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/architecture/__pycache__/emb_position.cpython-311.pyc b/code/xtuner/model/architecture/__pycache__/emb_position.cpython-311.pyc
deleted file mode 100644
index 536dfa67f374c8a641c47c1eb4c4cb93ccbf1c1c..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/architecture/__pycache__/emb_position.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/architecture/__pycache__/network.cpython-311.pyc b/code/xtuner/model/architecture/__pycache__/network.cpython-311.pyc
deleted file mode 100644
index 000f8d0157236f6b3795d2552995e099c3b017ab..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/architecture/__pycache__/network.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/architecture/__pycache__/nystrom_attention.cpython-311.pyc b/code/xtuner/model/architecture/__pycache__/nystrom_attention.cpython-311.pyc
deleted file mode 100644
index af710a9db281be37c7acf95c4235687d2828e0ac..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/architecture/__pycache__/nystrom_attention.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/architecture/__pycache__/transformer.cpython-311.pyc b/code/xtuner/model/architecture/__pycache__/transformer.cpython-311.pyc
deleted file mode 100644
index d2f99e42e89df9c1d3395db08e033a89472ea829..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/architecture/__pycache__/transformer.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/architecture/attmil.py b/code/xtuner/model/architecture/attmil.py
deleted file mode 100644
index 1ce2921ad6aa0d1eda1be7e1af60d158b9f008a8..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/attmil.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.models as models
-
-def initialize_weights(module):
- for m in module.modules():
- if isinstance(m,nn.Linear):
- # ref from clam
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m,nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
-class Resnet(nn.Module):
- def __init__(self):
- super(Resnet, self).__init__()
-
- self.model = list(models.resnet50(pretrained = True).children())[:-1]
- self.features = nn.Sequential(*self.model)
-
- self.feature_extractor_part2 = nn.Sequential(
- nn.Linear(2048, 4096),
- nn.ReLU(),
- nn.Dropout(p=0.25),
- nn.Linear(4096, 512),
- nn.ReLU(),
- nn.Dropout(p=0.25)
- )
- self.classifier = nn.Linear(512,1)
- initialize_weights(self.feature_extractor_part2)
- initialize_weights(self.classifier)
- def forward(self, x):
- x = self.features(x)
- x = x.view(x.size(0), -1)
- x=self.feature_extractor_part2(x)
- # feat = torch.mean(x,dim=0)
- x1 = self.classifier(x)
- # x2 = torch.mean(x1, dim=0).view(1,-1)
- x2,_ = torch.max(x1, dim=0)
- x2=x2.view(1,-1)
- return x2,x
-class AttentionGated(nn.Module):
- def __init__(self,input_dim=512,act='relu',bias=False,dropout=False):
- super(AttentionGated, self).__init__()
- self.L = 512
- self.D = 128 #128
- self.K = 1
-
- self.feature = [nn.Linear(1024, 512)]
- self.feature += [nn.ReLU()]
- self.feature += [nn.Dropout(0.25)]
- self.feature = nn.Sequential(*self.feature)
-
- self.classifier = nn.Sequential(
- nn.Linear(self.L*self.K, 2),
- )
-
- self.attention_a = [
- nn.Linear(self.L, self.D,bias=bias),
- ]
- if act == 'gelu':
- self.attention_a += [nn.GELU()]
- elif act == 'relu':
- self.attention_a += [nn.ReLU()]
- elif act == 'tanh':
- self.attention_a += [nn.Tanh()]
-
- self.attention_b = [nn.Linear(self.L, self.D,bias=bias),
- nn.Sigmoid()]
-
- if dropout:
- self.attention_a += [nn.Dropout(0.25)]
- self.attention_b += [nn.Dropout(0.25)]
-
- self.attention_a = nn.Sequential(*self.attention_a)
- self.attention_b = nn.Sequential(*self.attention_b)
-
- self.attention_c = nn.Linear(self.D, self.K,bias=bias)
-
- self.apply(initialize_weights)
- def forward(self, x):
- x = self.feature(x.squeeze(0))
-
- a = self.attention_a(x)
- b = self.attention_b(x)
- A = a.mul(b)
- A = self.attention_c(A)
-
- A = torch.transpose(A, -1, -2) # KxN
- A = F.softmax(A, dim=-1) # softmax over N
- x = torch.matmul(A,x)
-
- Y_prob = self.classifier(x)
-
- return Y_prob
-
-class DAttention(nn.Module):
- def __init__(self,n_classes,dropout,act):
- super(DAttention, self).__init__()
- self.L = 512 #512
- self.D = 128 #128
- self.K = 1
- self.feature = [nn.Linear(1024, 512)]
-
- if act.lower() == 'gelu':
- self.feature += [nn.GELU()]
- else:
- self.feature += [nn.ReLU()]
-
- if dropout:
- self.feature += [nn.Dropout(0.25)]
-
- self.feature = nn.Sequential(*self.feature)
-
- self.attention = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Tanh(),
- nn.Linear(self.D, self.K)
- )
- self.classifier = nn.Sequential(
- nn.Linear(self.L*self.K, n_classes),
- )
-
- self.apply(initialize_weights)
- def forward(self, x, return_attn=False,no_norm=False):
- feature = self.feature(x)
-
- # feature = group_shuffle(feature)
- feature = feature.squeeze(0)
- A = self.attention(feature)
- A_ori = A.clone()
- A = torch.transpose(A, -1, -2) # KxN
- A = F.softmax(A, dim=-1) # softmax over N
- M = torch.mm(A, feature) # KxL
- Y_prob = self.classifier(M)
-
- if return_attn:
- if no_norm:
- return Y_prob,A_ori
- else:
- return Y_prob,A
- else:
- return Y_prob
-if __name__ == "__main__":
- x=torch.rand(5,3,64,64).cuda()
- gcnnet=Resnet().cuda()
- Y_prob=gcnnet(x)
- criterion = nn.BCEWithLogitsLoss()
- # loss_max = criterion(Y_prob[1].view(1,-1), label.view(1,-1))
- print(Y_prob)
-
diff --git a/code/xtuner/model/architecture/bmil.py b/code/xtuner/model/architecture/bmil.py
deleted file mode 100644
index ecc5983a36a079fba70f98ac1933eb994e825751..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/bmil.py
+++ /dev/null
@@ -1,461 +0,0 @@
-import math
-import numbers
-import time
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch import Tensor
-from utils.utils import initialize_weights
-from architecture.linear_vdo import LinearVDO, Conv2dVDO
-import numpy as np
-from torch.distributions import kl
-
-EPS_1 = 1e-16
-# EPS_2 = 1e-28
-
-"""
-Attention Network without Gating (2 fc layers)
-args:
- L: input feature dimension
- D: hidden layer dimension
- dropout: whether to use dropout (p = 0.25)
- n_classes: number of classes
-"""
-
-
-class Attn_Net(nn.Module):
-
- def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
- super(Attn_Net, self).__init__()
- self.module = [
- nn.Linear(L, D),
- nn.Tanh()]
-
- if dropout:
- self.module.append(nn.Dropout(0.25))
-
- self.module.append(nn.Linear(D, n_classes))
-
- self.module = nn.Sequential(*self.module)
-
- def forward(self, x):
- return self.module(x), x # N x n_classes
-
-
-"""
-Attention Network with Sigmoid Gating (3 fc layers)
-args:
- L: input feature dimension
- D: hidden layer dimension
- dropout: whether to use dropout (p = 0.25)
- n_classes: number of classes
-"""
-
-
-class Attn_Net_Gated(nn.Module):
- def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
- super(Attn_Net_Gated, self).__init__()
- ard_init = -1.
- self.attention_a = [
- LinearVDO(L, D, ard_init=ard_init),
- nn.Tanh()]
-
- self.attention_b = [LinearVDO(L, D, ard_init=ard_init),
- nn.Sigmoid()]
- if dropout:
- self.attention_a.append(nn.Dropout(0.25))
- self.attention_b.append(nn.Dropout(0.25))
-
- self.attention_a = nn.Sequential(*self.attention_a)
- self.attention_b = nn.Sequential(*self.attention_b)
-
- self.attention_c = LinearVDO(D, n_classes, ard_init=ard_init)
-
- def forward(self, x):
- a = self.attention_a(x)
- b = self.attention_b(x)
- A = a.mul(b)
- A = self.attention_c(A) # N x n_classes
- return A, x
-
-
-class DAttn_Net_Gated(nn.Module):
- def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
- super(DAttn_Net_Gated, self).__init__()
- self.attention_a = [
- nn.Linear(L, D),
- nn.Tanh()]
-
- self.attention_b = [nn.Linear(L, D),
- nn.Sigmoid()]
- if dropout:
- self.attention_a.append(nn.Dropout(0.25))
- self.attention_b.append(nn.Dropout(0.25))
-
- self.attention_a = nn.Sequential(*self.attention_a)
- self.attention_b = nn.Sequential(*self.attention_b)
-
- self.attention_c = nn.Linear(D, n_classes)
-
- def forward(self, x):
- a = self.attention_a(x)
- b = self.attention_b(x)
- A = a.mul(b)
- A = self.attention_c(A) # N x n_classes
- # print(x.shape)
- return A, x
-
-
-class GaussianSmoothing(nn.Module):
- """
- Apply gaussian smoothing on a
- 1d, 2d or 3d tensor. Filtering is performed seperately for each channel
- in the input using a depthwise convolution.
- Arguments:
- channels (int, sequence): Number of channels of the input tensors. Output will
- have this number of channels as well.
- kernel_size (int, sequence): Size of the gaussian kernel.
- sigma (float, sequence): Standard deviation of the gaussian kernel.
- dim (int, optional): The number of dimensions of the data.
- Default value is 2 (spatial).
- """
-
- def __init__(self, channels, kernel_size, sigma, dim=2):
- super(GaussianSmoothing, self).__init__()
- if isinstance(kernel_size, numbers.Number):
- kernel_size = [kernel_size] * dim
- if isinstance(sigma, numbers.Number):
- sigma = [sigma] * dim
-
- # The gaussian kernel is the product of the
- # gaussian function of each dimension.
- kernel = 1
- meshgrids = torch.meshgrid(
- [
- torch.arange(size, dtype=torch.float32)
- for size in kernel_size
- ]
- )
-
- for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
- mean = (size - 1) / 2
- kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
- torch.exp(-((mgrid - mean) / std) ** 2 / 2)
-
- # Make sure sum of values in gaussian kernel equals 1.
- kernel = kernel / torch.sum(kernel)
-
- # Reshape to depthwise convolutional weight
- kernel = kernel.view(1, 1, *kernel.size())
- kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
-
- self.register_buffer('weight', kernel)
- self.groups = channels
-
- if dim == 1:
- self.conv = F.conv1d
- elif dim == 2:
- self.conv = F.conv2d
- elif dim == 3:
- self.conv = F.conv3d
- else:
- raise RuntimeError(
- 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
- )
-
- def forward(self, input):
- """
- Apply gaussian filter to input.
- Arguments:
- input (torch.Tensor): Input to apply gaussian filter on.
- Returns:
- filtered (torch.Tensor): Filtered output.
- """
- # return self.conv(input, weight=self.weight, groups=self.groups, dilation=2)
- return self.conv(input, weight=self.weight, groups=self.groups)
-
-
-class probabilistic_MIL_Bayes_vis(nn.Module):
- def __init__(self, gate=True, size_arg="small", dropout=False, n_classes=2, top_k=1):
- super(probabilistic_MIL_Bayes_vis, self).__init__()
- self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
- size = self.size_dict[size_arg]
- fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
- if dropout:
- fc.append(nn.Dropout(0.25))
- if gate:
- attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=2)
- else:
- attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
- fc.append(attention_net)
- self.attention_net = nn.Sequential(*fc)
- self.classifiers = LinearVDO(size[1], n_classes, ard_init=-3.)
- self.n_classes = n_classes
- self.print_sample_trigger = False
- self.num_samples = 16
- self.temperature = torch.tensor([1.0])
- self.fixed_b = torch.tensor([5.], requires_grad=False)
-
- initialize_weights(self)
- self.top_k = top_k
-
- def reparameterize(self, mu, logvar):
- std = torch.exp(0.5 * logvar)
- eps = torch.randn_like(std)
- return mu + eps * std
-
- def relocate(self):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.attention_net = self.attention_net.to(device)
- self.classifiers = self.classifiers.to(device)
- self.temperature = self.temperature.to(device)
-
- def forward(self, h, validation=False):
- device = h.device
- # *-*# A, h = self.attention_net(h) # NxK
-
- A, h = self.attention_net(h)
-
- mu = A[:, 0]
- logvar = A[:, 1]
- gaus_samples = self.reparameterize(mu, logvar)
- beta_samples = F.sigmoid(gaus_samples)
- A = beta_samples.unsqueeze(0)
- # print('gaus max: {0:.4f}, gaus min: {1:.4f}.'.format(torch.max(gaus_samples), torch.min(gaus_samples)))
- # print('sample max: {0:.4f}, sample min: {1:.4f}.'.format(torch.max(A), torch.min(A)))
-
- M = torch.mm(A, h) / A.sum()
- logits = self.classifiers(M)
- y_probs = F.softmax(logits, dim=1)
- top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
- top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
- Y_hat = torch.topk(top_instance, 1, dim=1)[1]
- Y_prob = F.softmax(top_instance, dim=1)
- # results_dict = {}
-
- # if return_features:
- # top_features = torch.index_select(h, dim=0, index=top_instance_idx)
- # results_dict.update({'features': top_features})
- return top_instance, Y_prob, Y_hat, y_probs, A
-
-
-class probabilistic_MIL_Bayes_enc(nn.Module):
- def __init__(self, gate=True, size_arg="small", dropout=False, n_classes=2, top_k=1):
- super(probabilistic_MIL_Bayes_enc, self).__init__()
- self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
- size = self.size_dict[size_arg]
- first_transform = nn.Linear(size[0], size[1])
- fc1 = [first_transform, nn.ReLU()]
-
- if dropout:
- fc1.append(nn.Dropout(0.25))
-
- if gate:
- postr_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=2)
- else:
- postr_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
-
- fc1.append(postr_net)
-
- self.postr_net = nn.Sequential(*fc1)
- self.classifiers = LinearVDO(size[1], n_classes, ard_init=-3.)
-
- self.n_classes = n_classes
- self.print_sample_trigger = False
- self.num_samples = 16
- self.temperature = torch.tensor([1.0])
- self.prior_mu = torch.tensor([-5., 0.])
- self.prior_logvar = torch.tensor([-1., 3.])
-
- initialize_weights(self)
- self.top_k = top_k
-
- def reparameterize(self, mu, logvar):
- std = torch.exp(0.5 * logvar)
- eps = torch.randn_like(std)
- return mu + eps * std
-
- def relocate(self):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # self.attention_net = self.attention_net.to(device)
- self.postr_net = self.postr_net.to(device)
- # self.prior_net = self.prior_net.to(device)
- self.classifiers = self.classifiers.to(device)
- self.temperature = self.temperature.to(device)
- self.prior_mu = self.prior_mu.to(device)
- self.prior_logvar = self.prior_logvar.to(device)
-
- def kl_logistic_normal(self, mu_pr, mu_pos, logvar_pr, logvar_pos):
- return (logvar_pr - logvar_pos) / 2. + (logvar_pos ** 2 + (mu_pr - mu_pos) ** 2) / (2. * logvar_pr ** 2) - 0.5
-
- def forward(self, h, return_features=False, slide_label=None, validation=False):
- device = h.device
- # *-*# A, h = self.attention_net(h) # NxK
-
- param, h = self.postr_net(h)
-
- mu = param[:, 0]
- logvar = param[:, 1]
- gaus_samples = self.reparameterize(mu, logvar)
- beta_samples = F.sigmoid(gaus_samples)
- A = beta_samples.unsqueeze(0)
-
- if not validation:
- mu_pr = self.prior_mu[slide_label.item()].expand(h.shape[0])
- logvar_pr = self.prior_logvar[slide_label.item()]
- kl_div = self.kl_logistic_normal(mu_pr, mu, logvar_pr, logvar)
- else:
- kl_div = None
-
- M = torch.mm(A, h) / A.sum()
-
- logits = self.classifiers(M)
-
- y_probs = F.softmax(logits, dim=1)
- top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
- top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
- Y_hat = torch.topk(top_instance, 1, dim=1)[1]
- Y_prob = F.softmax(top_instance, dim=1)
- results_dict = {}
-
- if return_features:
- top_features = torch.index_select(h, dim=0, index=top_instance_idx)
- results_dict.update({'features': top_features})
- if not validation:
- return top_instance, Y_prob, Y_hat, kl_div, y_probs, A
- else:
- return top_instance, Y_prob, Y_hat, y_probs, A
-
-
-
-class probabilistic_MIL_Bayes_spvis(nn.Module):
- def __init__(self, conf, size_arg="small", top_k=1):
- super(probabilistic_MIL_Bayes_spvis, self).__init__()
-
- # self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384]}
- self.size_dict = {"small": [conf.feat_d, 512, 256], "big": [conf.feat_d, 512, 384]}
- size = self.size_dict[size_arg]
-
- ard_init = -4.
- self.linear1 = nn.Linear(size[0], size[1])
- self.linear2a = LinearVDO(size[1], size[2], ard_init=ard_init)
- self.linear2b = LinearVDO(size[1], size[2], ard_init=ard_init)
- self.linear3 = LinearVDO(size[2], 2, ard_init=ard_init)
-
- self.gaus_smoothing = GaussianSmoothing(1, 3, 0.5)
-
- self.classifiers = LinearVDO(size[1], conf.n_class, ard_init=-3.)
-
- self.dp_0 = nn.Dropout(0.25)
- self.dp_a = nn.Dropout(0.25)
- self.dp_b = nn.Dropout(0.25)
-
- self.prior_mu = torch.tensor([-5., 0.])
- self.prior_logvar = torch.tensor([-1., 3.])
-
- initialize_weights(self)
- self.top_k = top_k
- self.patch_size = conf.patch_size
-
- def reparameterize(self, mu, logvar):
- std = torch.exp(0.5 * logvar)
- eps = torch.randn_like(std)
- return mu + eps * std
-
- def kl_logistic_normal(self, mu_pr, mu_pos, logvar_pr, logvar_pos):
- return (logvar_pr - logvar_pos) / 2. + (logvar_pos ** 2 + (mu_pr - mu_pos) ** 2) / (2. * logvar_pr ** 2) - 0.5
-
- def relocate(self):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- self.linear1 = self.linear1.to(device)
- self.linear2a = self.linear2a.to(device)
- self.linear2b = self.linear2b.to(device)
- self.linear3 = self.linear3.to(device)
-
- self.dp_0 = self.dp_0.to(device)
- self.dp_a = self.dp_a.to(device)
- self.dp_b = self.dp_b.to(device)
- self.gaus_smoothing = self.gaus_smoothing.to(device)
-
- self.prior_mu = self.prior_mu.to(device)
- self.prior_logvar = self.prior_logvar.to(device)
-
- self.classifiers = self.classifiers.to(device)
-
- def forward(self, h, coords, height, width, slide_label=None, validation=False):
- h = h[0]
- device = h.device
- h = F.relu(self.dp_0(self.linear1(h)))
-
- feat_a = self.dp_a(torch.sigmoid(self.linear2a(h)))
- feat_b = self.dp_b(torch.tanh(self.linear2b(h)))
- feat = feat_a.mul(feat_b)
- params = self.linear3(feat)
-
- coords = coords // self.patch_size
- asign = lambda coord: coord[:, 0] + coord[:, 1] * (width // self.patch_size)
- coords = asign(coords)
- coords = torch.from_numpy(coords).to(device)
-
- mu = torch.zeros([1, (height // self.patch_size + 1) * (width // self.patch_size + 1)]).to(device)
- logvar = torch.zeros([1, (height // self.patch_size + 1) * (width // self.patch_size + 1)]).to(device)
-
- mu[:, coords.long()] = params[:, 0]
- logvar[:, coords.long()] = params[:, 1]
-
- mu = mu.view(1, height // self.patch_size + 1, width // self.patch_size + 1)
- logvar = logvar.view(1, height // self.patch_size + 1, width // self.patch_size + 1)
-
- if not validation:
- mu_pr = self.prior_mu[slide_label.item()].expand_as(mu)
- logvar_pr = self.prior_logvar[slide_label.item()]
- kl_div = self.kl_logistic_normal(mu_pr, mu, logvar_pr, logvar)
- else:
- kl_div = None
-
- # # no branch
- mu = F.pad(mu, (1, 1, 1, 1), mode='constant', value=0)
- mu = torch.unsqueeze(mu, dim=0)
- mu = self.gaus_smoothing(mu)
-
- gaus_samples = self.reparameterize(mu, logvar)
- gaus_samples = torch.squeeze(gaus_samples, dim=0)
-
- A = F.sigmoid(gaus_samples)
- A = A.view(1, -1)
-
- patch_A = torch.index_select(A, dim=1, index=coords)
- M = torch.mm(patch_A, h) / patch_A.sum()
-
- logits = self.classifiers(M)
-
- y_probs = F.softmax(logits, dim=1)
- top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
- top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
- Y_hat = torch.topk(top_instance, 1, dim=1)[1]
- Y_prob = F.softmax(top_instance, dim=1)
-
- if not validation:
- return top_instance, Y_prob, Y_hat, kl_div, y_probs, patch_A.view((1, -1))
- else:
- return top_instance, Y_prob, Y_hat, y_probs, patch_A.view((1, -1))
-
-
-def get_ard_reg_vdo(module, reg=0):
- """
- :param module: model to evaluate ard regularization for
- :param reg: auxilary cumulative variable for recursion
- :return: total regularization for module
- """
- if isinstance(module, LinearVDO) or isinstance(module, Conv2dVDO): return reg + module.get_reg()
- if hasattr(module, 'children'): return reg + sum([get_ard_reg_vdo(submodule) for submodule in module.children()])
- return reg
-
-
-bMIL_model_dict = {
- 'vis': probabilistic_MIL_Bayes_vis,
- 'enc': probabilistic_MIL_Bayes_enc,
- 'spvis': probabilistic_MIL_Bayes_spvis,
-}
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/clam.py b/code/xtuner/model/architecture/clam.py
deleted file mode 100644
index daa37dea8e6ab6f10a2a288f619a526ca5021126..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/clam.py
+++ /dev/null
@@ -1,282 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from utils.utils import initialize_weights, softmax_one
-import numpy as np
-
-"""
-Attention Network without Gating (2 fc layers)
-args:
- L: input feature dimension
- D: hidden layer dimension
- dropout: whether to use dropout (p = 0.25)
- n_classes: number of classes
-"""
-
-
-class Attn_Net(nn.Module):
-
- def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
- super(Attn_Net, self).__init__()
- self.module = [
- nn.Linear(L, D),
- nn.Tanh()]
-
- if dropout:
- self.module.append(nn.Dropout(0.25))
-
- self.module.append(nn.Linear(D, n_classes))
-
- self.module = nn.Sequential(*self.module)
-
- def forward(self, x):
- return self.module(x), x # N x n_classes
-
-
-"""
-Attention Network with Sigmoid Gating (3 fc layers)
-args:
- L: input feature dimension
- D: hidden layer dimension
- dropout: whether to use dropout (p = 0.25)
- n_classes: number of classes
-"""
-
-
-class Attn_Net_Gated(nn.Module):
- def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
- super(Attn_Net_Gated, self).__init__()
- self.attention_a = [
- nn.Linear(L, D),
- nn.Tanh()]
-
- self.attention_b = [nn.Linear(L, D),
- nn.Sigmoid()]
- if dropout:
- self.attention_a.append(nn.Dropout(0.25))
- self.attention_b.append(nn.Dropout(0.25))
-
- self.attention_a = nn.Sequential(*self.attention_a)
- self.attention_b = nn.Sequential(*self.attention_b)
-
- self.attention_c = nn.Linear(D, n_classes)
-
- def forward(self, x):
- a = self.attention_a(x)
- b = self.attention_b(x)
- A = a.mul(b)
- A = self.attention_c(A) # N x n_classes
- return A, x
-
-
-"""
-args:
- gate: whether to use gated attention network
- size_arg: config for network size
- dropout: whether to use dropout
- k_sample: number of positive/neg patches to sample for instance-level training
- dropout: whether to use dropout (p = 0.25)
- n_classes: number of classes
- instance_loss_fn: loss function to supervise instance-level training
- subtyping: whether it's a subtyping problem
-"""
-
-
-class CLAM_SB(nn.Module):
- def __init__(self, conf, gate=True, size_arg="small", k_sample=8, dropout=True,
- instance_loss_fn=nn.CrossEntropyLoss()):
- super(CLAM_SB, self).__init__()
- n_classes = conf.n_class
- self.size_dict = {"small": [conf.D_feat, conf.D_inner, 128], "big": [conf.D_feat, 512, 384]}
- size = self.size_dict[size_arg]
- fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
- if dropout:
- fc.append(nn.Dropout(0.25))
- if gate:
- attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
- else:
- attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
- fc.append(attention_net)
- self.attention_net = nn.Sequential(*fc)
- self.classifiers = nn.Linear(size[1], n_classes)
- instance_classifiers = [nn.Linear(size[1], 2) for i in range(n_classes)]
- self.instance_classifiers = nn.ModuleList(instance_classifiers)
- self.k_sample = k_sample
- self.instance_loss_fn = instance_loss_fn
- self.n_classes = n_classes
- self.subtyping = False
- if conf.n_class > 2:
- self.subtyping = True
-
- initialize_weights(self)
-
- def relocate(self):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.attention_net = self.attention_net.to(device)
- self.classifiers = self.classifiers.to(device)
- self.instance_classifiers = self.instance_classifiers.to(device)
-
- @staticmethod
- def create_positive_targets(length, device):
- return torch.full((length,), 1, device=device).long()
-
- @staticmethod
- def create_negative_targets(length, device):
- return torch.full((length,), 0, device=device).long()
-
- # instance-level evaluation for in-the-class attention branch
- def inst_eval(self, A, h, classifier):
- device = h.device
- if len(A.shape) == 1:
- A = A.view(1, -1)
- top_p_ids = torch.topk(A, self.k_sample)[1][-1]
- top_p = torch.index_select(h, dim=0, index=top_p_ids)
- top_n_ids = torch.topk(-A, self.k_sample, dim=1)[1][-1]
- top_n = torch.index_select(h, dim=0, index=top_n_ids)
- p_targets = self.create_positive_targets(self.k_sample, device)
- n_targets = self.create_negative_targets(self.k_sample, device)
-
- all_targets = torch.cat([p_targets, n_targets], dim=0)
- all_instances = torch.cat([top_p, top_n], dim=0)
- logits = classifier(all_instances)
- all_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
- instance_loss = self.instance_loss_fn(logits, all_targets)
- return instance_loss, all_preds, all_targets
-
- # instance-level evaluation for out-of-the-class attention branch
- def inst_eval_out(self, A, h, classifier):
- device = h.device
- if len(A.shape) == 1:
- A = A.view(1, -1)
- top_p_ids = torch.topk(A, self.k_sample)[1][-1]
- top_p = torch.index_select(h, dim=0, index=top_p_ids)
- p_targets = self.create_negative_targets(self.k_sample, device)
- logits = classifier(top_p)
- p_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
- instance_loss = self.instance_loss_fn(logits, p_targets)
- return instance_loss, p_preds, p_targets
-
- def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
- A, h = self.attention_net(h[0]) # NxK
- A = torch.transpose(A, -1, -2) # KxN
- if attention_only:
- return A
-
-
- A_raw = A
- A = F.softmax(A, dim=-1) # softmax over N
-
- if instance_eval:
- total_inst_loss = 0.0
- all_preds = []
- all_targets = []
- inst_labels = F.one_hot(label, num_classes=self.n_classes).squeeze() # binarize label
- for i in range(len(self.instance_classifiers)):
- inst_label = inst_labels[i].item()
- classifier = self.instance_classifiers[i]
- if inst_label == 1: # in-the-class:
- instance_loss, preds, targets = self.inst_eval(A, h, classifier)
- all_preds.extend(preds.cpu().numpy())
- all_targets.extend(targets.cpu().numpy())
- else: # out-of-the-class
- if self.subtyping:
- instance_loss, preds, targets = self.inst_eval_out(A, h, classifier)
- all_preds.extend(preds.cpu().numpy())
- all_targets.extend(targets.cpu().numpy())
- else:
- continue
- total_inst_loss += instance_loss
-
- if self.subtyping:
- total_inst_loss /= len(self.instance_classifiers)
-
- M = torch.mm(A, h)
- logits = self.classifiers(M)
- if instance_eval:
- return logits, total_inst_loss
- else:
- return logits
- # Y_hat = torch.topk(logits, 1, dim=1)[1]
- # Y_prob = F.softmax(logits, dim=1)
- # if instance_eval:
- # results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
- # 'inst_preds': np.array(all_preds)}
- # else:
- # results_dict = {}
- # if return_features:
- # results_dict.update({'features': M})
- # return logits, Y_prob, Y_hat, A_raw, results_dict
-
-
-class CLAM_MB(CLAM_SB):
- def __init__(self, conf, gate=True, size_arg="small", k_sample=8, dropout=True,
- instance_loss_fn=nn.CrossEntropyLoss()):
- nn.Module.__init__(self)
- n_classes = conf.n_class
- self.size_dict = {"small": [conf.D_feat, conf.D_inner, 128], "big": [conf.D_feat, 512, 384]}
- size = self.size_dict[size_arg]
- fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
- if dropout:
- fc.append(nn.Dropout(0.25))
- if gate:
- attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=n_classes)
- else:
- attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=n_classes)
- fc.append(attention_net)
- self.attention_net = nn.Sequential(*fc)
- bag_classifiers = [nn.Linear(size[1], 1) for i in
- range(n_classes)] # use an indepdent linear layer to predict each class
- self.classifiers = nn.ModuleList(bag_classifiers)
- instance_classifiers = [nn.Linear(size[1], 2) for i in range(n_classes)]
- self.instance_classifiers = nn.ModuleList(instance_classifiers)
- self.k_sample = k_sample
- self.instance_loss_fn = instance_loss_fn
- self.n_classes = n_classes
- self.subtyping = False
- if conf.n_class > 2:
- self.subtyping = True
- initialize_weights(self)
-
- def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
- device = h.device
- h = h[0]
- A, h = self.attention_net(h) # NxK
- A = torch.transpose(A, 1, 0) # KxN
- if attention_only:
- return A
- A_raw = A
- A = softmax_one(A, dim=1) # softmax over N
-
- if instance_eval:
- total_inst_loss = 0.0
- all_preds = []
- all_targets = []
- inst_labels = F.one_hot(label, num_classes=self.n_classes).squeeze() # binarize label
- for i in range(len(self.instance_classifiers)):
- inst_label = inst_labels[i].item()
- classifier = self.instance_classifiers[i]
- if inst_label == 1: # in-the-class:
- instance_loss, preds, targets = self.inst_eval(A[i], h, classifier)
- all_preds.extend(preds.cpu().numpy())
- all_targets.extend(targets.cpu().numpy())
- else: # out-of-the-class
- if self.subtyping:
- instance_loss, preds, targets = self.inst_eval_out(A[i], h, classifier)
- all_preds.extend(preds.cpu().numpy())
- all_targets.extend(targets.cpu().numpy())
- else:
- continue
- total_inst_loss += instance_loss
-
- if self.subtyping:
- total_inst_loss /= len(self.instance_classifiers)
-
- M = torch.mm(A, h)
- logits = torch.empty(1, self.n_classes).float().to(device)
- for c in range(self.n_classes):
- logits[0, c] = self.classifiers[c](M[c])
- if instance_eval:
- return logits, total_inst_loss
- else:
- return logits
-
diff --git a/code/xtuner/model/architecture/dsmil.py b/code/xtuner/model/architecture/dsmil.py
deleted file mode 100644
index 701c99224fea6527537bda40aef5bae17256a514..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/dsmil.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-# vpt
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torchvision as tv
-from functools import reduce
-from operator import mul
-
-
-class FCLayer(nn.Module):
- def __init__(self, in_size, out_size=1):
- super(FCLayer, self).__init__()
- self.fc = nn.Sequential(nn.Linear(in_size, out_size))
-
- def forward(self, feats):
- x = self.fc(feats)
- return feats, x
-
-
-class IClassifier(nn.Module):
- def __init__(self, feature_extractor, feature_size, output_class):
- super(IClassifier, self).__init__()
-
- self.feature_extractor = feature_extractor
- self.fc = nn.Linear(feature_size, output_class)
-
- def forward(self, x):
- device = x.device
- feats = self.feature_extractor(x) # N x K
- c = self.fc(feats.view(feats.shape[0], -1)) # N x C
- return feats.view(feats.shape[0], -1), c
-
-
-class BClassifier(nn.Module):
- def __init__(self, conf, dropout_v=0.0, nonlinear=True, passing_v=False,
- confounder_path=False): # K, L, N
- super(BClassifier, self).__init__()
- input_size=conf.D_feat
- output_class=conf.n_class
- if nonlinear:
- self.q = nn.Sequential(nn.Linear(input_size, conf.D_inner), nn.ReLU(), nn.Linear(conf.D_inner, 128), nn.Tanh())
- else:
- self.q = nn.Linear(input_size, conf.D_inner)
- if passing_v:
- self.v = nn.Sequential(
- nn.Dropout(dropout_v),
- nn.Linear(input_size, input_size),
- nn.ReLU()
- )
- else:
- self.v = nn.Identity()
-
- ### 1D convolutional layer that can handle multiple class (including binary)
- self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)
-
-
- def forward(self, feats, c): # N x K, N x C
- device = feats.device
- V = self.v(feats) # N x V, unsorted
- Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted
- # handle multiple classes without for loop
- _, m_indices = torch.sort(c, 0,
- descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
- # print(m_indices.shape)
- m_feats = torch.index_select(feats, dim=0,
- index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
- q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q
- A = torch.mm(Q, q_max.transpose(0,
- 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
- A = A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)) # normalize attention scores, A in shape N x C,
- A = A.transpose(0, 1)
-
- A_out = A
- A = F.softmax(A, dim=-1)
- B = torch.mm(A, V) # compute bag representation, B in shape C x V
- B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V
-
- C = self.fcc(B) # 1 x C x 1
- C = C.view(1, -1)
- return C, A_out, B
-
-
-class MILNet(nn.Module):
- def __init__(self, i_classifier, b_classifier):
- super(MILNet, self).__init__()
- self.i_classifier = i_classifier
- self.b_classifier = b_classifier
-
- def forward(self, x):
- feats, classes = self.i_classifier(x[0])
- # print(feats)
- prediction_bag, A, B = self.b_classifier(feats, classes)
- return classes, prediction_bag, A
diff --git a/code/xtuner/model/architecture/emb_position.py b/code/xtuner/model/architecture/emb_position.py
deleted file mode 100644
index fa8fef7c1cf5e2a6d7d7c8e4a9781ff764fcfee4..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/emb_position.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import torch
-from torch import nn
-import numpy as np
-
-class PPEG(nn.Module):
- def __init__(self, dim=512,k=7,conv_1d=False,bias=True):
- super(PPEG, self).__init__()
- self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim,bias=bias) if not conv_1d else nn.Conv2d(dim, dim, (k,1), 1, (k//2,0), groups=dim,bias=bias)
- self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5//2, groups=dim,bias=bias) if not conv_1d else nn.Conv2d(dim, dim, (5,1), 1, (5//2,0), groups=dim,bias=bias)
- self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3//2, groups=dim,bias=bias) if not conv_1d else nn.Conv2d(dim, dim, (3,1), 1, (3//2,0), groups=dim,bias=bias)
-
- def forward(self, x):
- B, N, C = x.shape
-
- # padding
- H, W = int(np.ceil(np.sqrt(N))), int(np.ceil(np.sqrt(N)))
-
- add_length = H * W - N
- # if add_length >0:
- x = torch.cat([x, x[:,:add_length,:]],dim = 1)
-
- if H < 7:
- H,W = 7,7
- zero_pad = H * W - (N+add_length)
- x = torch.cat([x, torch.zeros((B,zero_pad,C),device=x.device)],dim = 1)
- add_length += zero_pad
-
- # H, W = int(N**0.5),int(N**0.5)
- # cls_token, feat_token = x[:, 0], x[:, 1:]
- # feat_token = x
- cnn_feat = x.transpose(1, 2).view(B, C, H, W)
-
- x = self.proj(cnn_feat)+cnn_feat+self.proj1(cnn_feat)+self.proj2(cnn_feat)
- x = x.flatten(2).transpose(1, 2)
- # print(add_length)
- if add_length >0:
- x = x[:,:-add_length]
- # x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
- return x
-
-class PEG(nn.Module):
- def __init__(self, dim=512,k=7,bias=True,conv_1d=False):
- super(PEG, self).__init__()
- self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim,bias=bias) if not conv_1d else nn.Conv2d(dim, dim, (k,1), 1, (k//2,0), groups=dim,bias=bias)
-
- def forward(self, x):
- B, N, C = x.shape
-
- # padding
- H, W = int(np.ceil(np.sqrt(N))), int(np.ceil(np.sqrt(N)))
- add_length = H * W - N
- x = torch.cat([x, x[:,:add_length,:]],dim = 1)
-
- feat_token = x
- cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
- x = self.proj(cnn_feat)+cnn_feat
-
- x = x.flatten(2).transpose(1, 2)
- if add_length >0:
- x = x[:,:-add_length]
-
- # x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
- return x
-
-
-class SINCOS(nn.Module):
- def __init__(self,embed_dim=512):
- super(SINCOS, self).__init__()
- self.embed_dim = embed_dim
- self.pos_embed = self.get_2d_sincos_pos_embed(embed_dim, 8)
- def get_1d_sincos_pos_embed_from_grid(self,embed_dim, pos):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,)
- out: (M, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- return emb
-
- def get_2d_sincos_pos_embed_from_grid(self,embed_dim, grid):
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = self.get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- return emb
-
- def get_2d_sincos_pos_embed(self,embed_dim, grid_size, cls_token=False):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = self.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
- def forward(self, x):
- #B, N, C = x.shape
- B,H,W,C = x.shape
- # # padding
- # H, W = int(np.ceil(np.sqrt(N))), int(np.ceil(np.sqrt(N)))
- # add_length = H * W - N
- # x = torch.cat([x, x[:,:add_length,:]],dim = 1)
-
- # pos_embed = torch.zeros(1, H * W + 1, self.embed_dim)
- # pos_embed = self.get_2d_sincos_pos_embed(pos_embed.shape[-1], int(H), cls_token=True)
- #pos_embed = torch.from_numpy(self.pos_embed).float().unsqueeze(0).to(x.device)
-
- pos_embed = torch.from_numpy(self.pos_embed).float().to(x.device)
-
- # print(pos_embed.size())
- # print(x.size())
- x = x + pos_embed.unsqueeze(1).unsqueeze(1).repeat(1,H,W,1)
-
-
- #x = x + pos_embed[:, 1:, :]
-
- # if add_length >0:
- # x = x[:,:-add_length]
-
- return x
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/ibmil.py b/code/xtuner/model/architecture/ibmil.py
deleted file mode 100644
index 04dae2e35e69d439e453c00f6d65ae1166a92dc5..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/ibmil.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from architecture.network import Classifier_1fc, DimReduction
-
-class Attention_Gated(nn.Module):
- def __init__(self, L=512, D=128, K=1):
- super(Attention_Gated, self).__init__()
-
- self.L = L
- self.D = D
- self.K = K
-
- self.attention_V = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Tanh()
- )
-
- self.attention_U = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Sigmoid()
- )
-
- self.attention_weights = nn.Linear(self.D, self.K)
-
- def forward(self, x):
- ## x: N x L
- A_V = self.attention_V(x) # NxD
- A_U = self.attention_U(x) # NxD
- A = self.attention_weights(A_V * A_U) # NxK
- A = torch.transpose(A, 1, 0) # KxN
-
-
- return A ### K x N
-
-
-class IBMIL(nn.Module):
- def __init__(self, conf, confounder_dim=128, confounder_merge='cat'):
- super(IBMIL, self).__init__()
- self.confounder_merge = confounder_merge
- assert confounder_merge in ['cat', 'add', 'sub']
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.attention = Attention_Gated(conf.D_inner, 128, 1)
- self.classifier = Classifier_1fc(conf.D_inner, conf.n_class, 0)
- self.confounder_path = None
- if conf.c_path:
- print('deconfounding')
- self.confounder_path = conf.c_path
- conf_list = []
- for i in conf.c_path:
- conf_list.append(torch.from_numpy(np.load(i)).view(-1, conf.D_inner).float())
- conf_tensor = torch.cat(conf_list, 0)
- conf_tensor_dim = conf_tensor.shape[-1]
- if conf.c_learn:
- self.confounder_feat = nn.Parameter(conf_tensor, requires_grad=True)
- else:
- self.register_buffer("confounder_feat", conf_tensor)
- joint_space_dim = confounder_dim
- dropout_v = 0.5
- self.W_q = nn.Linear(conf.D_inner, joint_space_dim)
- self.W_k = nn.Linear(conf_tensor_dim, joint_space_dim)
- if confounder_merge == 'cat':
- self.classifier = nn.Linear(conf.D_inner + conf_tensor_dim, conf.n_class)
- elif confounder_merge == 'add' or 'sub':
- self.classifier = nn.Linear(conf.D_inner, conf.n_class)
- self.dropout = nn.Dropout(dropout_v)
-
- def forward(self, x):
- x = x[0]
- x = self.dimreduction(x)
- A = self.attention(x) ## K x N
- A = F.softmax(A, dim=1) # softmax over N
- M = torch.mm(A, x) ## K x L
- # x = x.squeeze(0)
-
- # H = self.feature_extractor_part1(x)
- # H = H.view(-1, 50 * 4 * 4)
- # H = self.feature_extractor_part2(H) # NxL
-
- # A = self.attention_1(x)
- # A = self.attention_2(A) # NxK
- # A = self.attention(x) # NxK
- # A = torch.transpose(A, 1, 0) # KxN
- # A = F.softmax(A, dim=1) # softmax over N
- # print('norm')
- # A = F.softmax(A/ torch.sqrt(torch.tensor(x.shape[1])), dim=1) # For Vis
-
- # M = torch.mm(A, x) # KxL
- if self.confounder_path:
- device = M.device
- # bag_q = self.confounder_W_q(M)
- # conf_k = self.confounder_W_k(self.confounder_feat)
- bag_q = self.W_q(M)
- conf_k = self.W_k(self.confounder_feat)
- deconf_A = torch.mm(conf_k, bag_q.transpose(0, 1))
- deconf_A = F.softmax(
- deconf_A / torch.sqrt(torch.tensor(conf_k.shape[1], dtype=torch.float32, device=device)),
- 0) # normalize attention scores, A in shape N x C,
- conf_feats = torch.mm(deconf_A.transpose(0, 1),
- self.confounder_feat) # compute bag representation, B in shape C x V
- if self.confounder_merge == 'cat':
- M = torch.cat((M, conf_feats), dim=1)
- elif self.confounder_merge == 'add':
- M = M + conf_feats
- elif self.confounder_merge == 'sub':
- M = M - conf_feats
- Y_prob = self.classifier(M)
- # Y_hat = torch.ge(Y_prob, 0.5).float()
- if self.confounder_path:
- return Y_prob, M, deconf_A
- else:
- return Y_prob, M, A
-
- # # AUXILIARY METHODS
- # def calculate_classification_error(self, X, Y):
- # Y = Y.float()
- # _, Y_hat, _ = self.forward(X)
- # error = 1. - Y_hat.eq(Y).cpu().float().mean().data.item()
- #
- # return error, Y_hat
- #
- # def calculate_objective(self, X, Y):
- # Y = Y.float()
- # Y_prob, _, A = self.forward(X)
- # Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
- # neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob)) # negative log bernoulli
- #
- # return neg_log_likelihood, A
-
diff --git a/code/xtuner/model/architecture/ilra.py b/code/xtuner/model/architecture/ilra.py
deleted file mode 100644
index fff31c582c6cace90e232b902688ec226c1a8d87..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/ilra.py
+++ /dev/null
@@ -1,157 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import math
-
-
-
-
-"""
-Exploring Low-Rank Property in Multiple Instance Learning for Whole Slide Image Classification
-Jinxi Xiang et al. ICLR 2023
-"""
-
-def initialize_weights(model):
- for m in model.modules():
- if isinstance(m, nn.Linear):
- nn.init.xavier_normal_(m.weight)
- # m.bias.data.zero_()
-
- elif isinstance(m, nn.BatchNorm1d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
-
-class MultiHeadAttention(nn.Module):
- """
- multi-head attention block
- """
-
- def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False, gated=False):
- super(MultiHeadAttention, self).__init__()
- self.dim_V = dim_V
- self.num_heads = num_heads
- self.multihead_attn = nn.MultiheadAttention(dim_V, num_heads)
- self.fc_q = nn.Linear(dim_Q, dim_V)
- self.fc_k = nn.Linear(dim_K, dim_V)
- self.fc_v = nn.Linear(dim_K, dim_V)
- if ln:
- self.ln0 = nn.LayerNorm(dim_V)
- self.ln1 = nn.LayerNorm(dim_V)
- self.fc_o = nn.Linear(dim_V, dim_V)
-
- self.gate = None
- if gated:
- self.gate = nn.Sequential(nn.Linear(dim_Q, dim_V), nn.SiLU())
-
- def forward(self, Q, K):
-
- Q0 = Q
-
- Q = self.fc_q(Q).transpose(0, 1)
- K, V = self.fc_k(K).transpose(0, 1), self.fc_v(K).transpose(0, 1)
-
- A, _ = self.multihead_attn(Q, K, V)
-
- O = (Q + A).transpose(0, 1)
- O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
- O = O + F.relu(self.fc_o(O))
- O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
-
- if self.gate is not None:
- O = O.mul(self.gate(Q0))
-
- return O
-
-
-class GAB(nn.Module):
- """
- equation (16) in the paper
- """
-
- def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
- super(GAB, self).__init__()
- self.latent = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) # low-rank matrix L
-
- nn.init.xavier_uniform_(self.latent)
-
- self.project_forward = MultiHeadAttention(dim_out, dim_in, dim_out, num_heads, ln=ln, gated=True)
- self.project_backward = MultiHeadAttention(dim_in, dim_out, dim_out, num_heads, ln=ln, gated=True)
-
- def forward(self, X):
- """
- This process, which utilizes 'latent_mat' as a proxy, has relatively low computational complexity.
- In some respects, it is equivalent to the self-attention function applied to 'X' with itself,
- denoted as self-attention(X, X), which has a complexity of O(n^2).
- """
- latent_mat = self.latent.repeat(X.size(0), 1, 1)
- H = self.project_forward(latent_mat, X) # project the high-dimensional X into low-dimensional H
- X_hat = self.project_backward(X, H) # recover to high-dimensional space X_hat
-
- return X_hat
-
-
-class NLP(nn.Module):
- """
- To obtain global features for classification, Non-Local Pooling is a more effective method
- than simple average pooling, which may result in degraded performance.
- """
-
- def __init__(self, dim, num_heads, num_seeds, ln=False):
- super(NLP, self).__init__()
- self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
- nn.init.xavier_uniform_(self.S)
- self.mha = MultiHeadAttention(dim, dim, dim, num_heads, ln=ln)
-
- def forward(self, X):
- global_embedding = self.S.repeat(X.size(0), 1, 1)
- ret = self.mha(global_embedding, X)
- return ret
-
-
-class ILRA(nn.Module):
- def __init__(self, num_layers=2, feat_dim=768, n_classes=2, hidden_feat=256, num_heads=8, topk=1, ln=False):
- super().__init__()
- # stack multiple GAB block
- gab_blocks = []
- for idx in range(num_layers):
- block = GAB(dim_in=feat_dim if idx == 0 else hidden_feat,
- dim_out=hidden_feat,
- num_heads=num_heads,
- num_inds=topk,
- ln=ln)
- gab_blocks.append(block)
-
- self.gab_blocks = nn.ModuleList(gab_blocks)
-
- # non-local pooling for classification
- self.pooling = NLP(dim=hidden_feat, num_heads=num_heads, num_seeds=topk, ln=ln)
-
- # classifier
- self.classifier = nn.Linear(in_features=hidden_feat, out_features=n_classes)
-
- initialize_weights(self)
- print(f"ilra2~")
-
- def forward(self, x):
- for block in self.gab_blocks:
- x = block(x)
-
- feat = self.pooling(x)
- logits = self.classifier(feat)
-
- logits = logits.squeeze(1)
- # Y_hat = torch.topk(logits, 1, dim=1)[1]
- # Y_prob = F.softmax(logits, dim=1)
-
- return logits
-
-
-if __name__ == "__main__":
- model = ILRA(feat_dim=1024, n_classes=2, hidden_feat=256, num_heads=8, topk=1)
- num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print(f"num of params: {num_params}")
-
- x = torch.randn((1, 1600, 1024))
- logits, prob, y_hat = model(x)
- print(f"y shape: {logits.shape}")
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/ips_net.py b/code/xtuner/model/architecture/ips_net.py
deleted file mode 100644
index 694b4ab8d73ca02310a2ac44c7f8b105faafbb59..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/ips_net.py
+++ /dev/null
@@ -1,265 +0,0 @@
-import sys
-import math
-
-import torch
-import torch.nn as nn
-from torchvision.models import resnet18, resnet50
-
-from utils.utils import shuffle_batch, shuffle_instance
-from architecture.transformer import Transformer, pos_enc_1d
-from torchvision import transforms
-import timm
-
-class IPSNet(nn.Module):
- """
- Net that runs all the main components:
- patch encoder, IPS, patch aggregator and classification head
- """
-
- def get_conv_patch_enc(self, enc_type, pretrained):
- if enc_type == 'resnet18':
- net = resnet18(pretrained=pretrained)
- net.fc = nn.Identity()
- elif enc_type == 'resnet50':
- net = resnet50(pretrained=pretrained)
- net.fc = nn.Identity()
- elif enc_type == 'vit_b16':
- net = timm.create_model('vit_base_patch16_224', pretrained=True)
- net.head = nn.Identity()
- return net
-
- def get_projector(self, n_chan_in, D):
- return nn.Sequential(
- nn.LayerNorm(n_chan_in, eps=1e-05, elementwise_affine=False),
- nn.Linear(n_chan_in, D),
- nn.BatchNorm1d(D),
- nn.ReLU()
- )
-
- def get_output_layers(self, tasks):
- """
- Create an output layer for each task according to task definition
- """
-
- D = self.D
- n_class = self.n_class
-
- output_layers = nn.ModuleDict()
- for task in tasks.values():
- if task['act_fn'] == 'softmax':
- act_fn = nn.Softmax(dim=-1)
- elif task['act_fn'] == 'sigmoid':
- act_fn = nn.Sigmoid()
-
- layers = [
- nn.Linear(D, n_class),
- act_fn
- ]
- output_layers[task['name']] = nn.Sequential(*layers)
-
- return output_layers
-
- def __init__(self, device, conf):
- super().__init__()
-
- self.device = device
- self.n_class = conf.n_class
- self.M = conf.M
- self.I = conf.I
- self.D = conf.D
- self.use_pos = conf.use_pos
- self.tasks = conf.tasks
- self.shuffle = conf.shuffle
- self.shuffle_style = conf.shuffle_style
- self.is_image = conf.is_image
-
- if self.is_image:
- self.encoder = self.get_conv_patch_enc(conf.enc_type, conf.pretrained)
- else:
- self.encoder = self.get_projector(conf.n_chan_in, self.D)
-
- # Define the multi-head cross-attention transformer
- self.transf = Transformer(conf.n_token, conf.H, conf.D, conf.D_k, conf.D_v,
- conf.D_inner, conf.attn_dropout, conf.dropout)
-
- # Optionally use standard 1d sinusoidal positional encoding
- if conf.use_pos:
- self.pos_enc = pos_enc_1d(conf.D, conf.N).unsqueeze(0).to(device)
- else:
- self.pos_enc = None
-
- # Define an output layer for each task
- self.output_layers = self.get_output_layers(conf.tasks)
-
- # Define transform function
- self.transform = transforms.Compose([transforms.Resize(224),
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
-
- def do_shuffle(self, patches, pos_enc):
- """
- Shuffles patches and pos_enc so that patches that have an equivalent score
- are sampled uniformly
- """
-
- shuffle_style = self.shuffle_style
- if shuffle_style == 'batch':
- patches, shuffle_idx = shuffle_batch(patches)
- if torch.is_tensor(pos_enc):
- pos_enc, _ = shuffle_batch(pos_enc, shuffle_idx)
- elif shuffle_style == 'instance':
- patches, shuffle_idx = shuffle_instance(patches, 1)
- if torch.is_tensor(pos_enc):
- pos_enc, _ = shuffle_instance(pos_enc, 1, shuffle_idx)
-
- return patches, pos_enc
-
- def score_and_select(self, emb, emb_pos, M, idx):
- """
- Scores embeddings and selects the top-M embeddings
- """
- D = emb.shape[2]
-
- emb_to_score = emb_pos if torch.is_tensor(emb_pos) else emb
-
- # Obtain scores from transformer
- attn = self.transf.get_scores(emb_to_score) # (B, M+I)
-
- # Get indixes of top-scoring patches
- top_idx = torch.topk(attn, M, dim = -1)[1] # (B, M)
-
- # Update memory buffers
- # Note: Scoring is based on `emb_to_score`, selection is based on `emb`
- mem_emb = torch.gather(emb, 1, top_idx.unsqueeze(-1).expand(-1,-1,D))
- mem_idx = torch.gather(idx, 1, top_idx)
-
- return mem_emb, mem_idx
-
- def get_preds(self, embeddings):
- preds = {}
- for task in self.tasks.values():
- t_name, t_id = task['name'], task['id']
- layer = self.output_layers[t_name]
-
- emb = embeddings[:,t_id]
- preds[t_name] = layer(emb)
-
- return preds
-
- # IPS runs in no-gradient mode
- @torch.no_grad()
- def ips(self, patches):
- """ Iterative Patch Selection """
-
- # Get useful variables
- M = self.M
- I = self.I
- D = self.D
- device = self.device
- shuffle = self.shuffle
- use_pos = self.use_pos
- pos_enc = self.pos_enc
- patch_shape = patches.shape
- B, N = patch_shape[:2]
-
- # Shortcut: IPS not required when memory is larger than total number of patches
- if M >= N:
- # Batchify pos enc
- pos_enc = pos_enc.expand(B, -1, -1) if use_pos else None
- return patches.to(device), pos_enc
-
- # IPS runs in evaluation mode
- if self.training:
- self.encoder.eval()
- self.transf.eval()
-
- # Batchify positional encoding
- if use_pos:
- pos_enc = pos_enc.expand(B, -1, -1)
-
- # Shuffle patches (i.e., randomize when patches obtain identical scores)
- if shuffle:
- patches, pos_enc = self.do_shuffle(patches, pos_enc)
-
- # Init memory buffer
- # Put patches onto GPU in case it is not there yet (lazy loading).
- # `to` will return self in case patches are located on GPU already (eager loading)
- init_patch = patches[:,:M].to(device)
- init_patch = self.transform(init_patch.reshape(-1, *patch_shape[2:]).div(255))
-
- ## Embed
- mem_emb = self.encoder(init_patch)
- mem_emb = mem_emb.view(B, M, -1)
-
- # Init memory indixes in order to select patches at the end of IPS.
- idx = torch.arange(N, dtype=torch.int64, device=device).unsqueeze(0).expand(B, -1)
- mem_idx = idx[:,:M]
-
- # Apply IPS for `n_iter` iterations
- n_iter = math.ceil((N - M) / I)
- for i in range(n_iter):
- # Get next patches
- start_idx = i * I + M
- end_idx = min(start_idx + I, N)
-
- iter_patch = patches[:, start_idx:end_idx].to(device)
- iter_patch = self.transform(iter_patch.reshape(-1, *patch_shape[2:]).div(255))
- iter_idx = idx[:, start_idx:end_idx]
-
- # Embed
- iter_emb = self.encoder(iter_patch)
- iter_emb = iter_emb.view(B, -1, D)
-
- # Concatenate with memory buffer
- all_emb = torch.cat((mem_emb, iter_emb), dim=1)
- all_idx = torch.cat((mem_idx, iter_idx), dim=1)
- # When using positional encoding, also apply it during patch selection
- if use_pos:
- all_pos_enc = torch.gather(pos_enc, 1, all_idx.view(B, -1, 1).expand(-1, -1, D))
- all_emb_pos = all_emb + all_pos_enc
- else:
- all_emb_pos = None
-
- # Select Top-M patches according to cross-attention scores
- mem_emb, mem_idx = self.score_and_select(all_emb, all_emb_pos, M, all_idx)
-
- # Select patches
- n_dim_expand = len(patch_shape) - 2
- mem_patch = torch.gather(patches, 1,
- mem_idx.view(B, -1, *(1,)*n_dim_expand).expand(-1, -1, *patch_shape[2:]).to(patches.device)
- ).to(device)
-
- if use_pos:
- mem_pos = torch.gather(pos_enc, 1, mem_idx.unsqueeze(-1).expand(-1, -1, D))
- else:
- mem_pos = None
-
- # Set components back to training mode
- # Although components of `self` that are relevant for IPS have been set to eval mode,
- # self is still in training mode at training time, i.e., we can use it here.
- if self.training:
- self.encoder.eval()
- self.transf.train()
-
- # Return selected patch and corresponding positional embeddings
- return mem_patch, mem_pos
-
- def forward(self, mem_patch, mem_pos=None):
- """
- After M patches have been selected during IPS, encode and aggregate them.
- The aggregated embedding is input to a classification head.
- """
-
- patch_shape = mem_patch.shape
- B, M = patch_shape[:2]
-
- mem_emb = self.encoder(self.transform(mem_patch.reshape(-1, *patch_shape[2:]).div(255)))
- mem_emb = mem_emb.view(B, M, -1)
-
- if torch.is_tensor(mem_pos):
- mem_emb = mem_emb + mem_pos
-
- image_emb = self.transf(mem_emb)
-
- preds = self.get_preds(image_emb)
-
- return preds
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/lbmil.py b/code/xtuner/model/architecture/lbmil.py
deleted file mode 100644
index 4ebf939a634898a2b52a1dd7056729b9d43d1c63..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/lbmil.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from architecture.network import Classifier_1fc, DimReduction
-
-
-
-class AttentionLayer(nn.Module):
- def __init__(self, dim=512):
- super(AttentionLayer, self).__init__()
- self.dim = dim
-
- def forward(self, features, W_1, b_1):
- out_c = F.linear(features, W_1, b_1)
- out = out_c - out_c.max()
- out = out.exp()
- out = out.sum(1, keepdim=True)
- alpha = out / out.sum(0)
-
- alpha01 = features.size(0) * alpha.expand_as(features)
- context = torch.mul(features, alpha01)
-
- return context, out_c, torch.squeeze(alpha)
-
-class LBMIL(nn.Module):
- def __init__(self, conf, droprate=0):
- super(LBMIL, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.attention = AttentionLayer(conf.D_inner)
- self.classifier = nn.Linear(conf.D_inner, conf.n_class)
-
- def forward(self, x): ## x: N x L
- x = x[0]
- med_feat = self.dimreduction(x)
- out, out_c, alpha = self.attention(med_feat, self.classifier.weight, self.classifier.bias)
- out = out.mean(0, keepdim=True)
-
- y = self.classifier(out)
- return y, out_c, alpha
-
-
-
-
diff --git a/code/xtuner/model/architecture/linear_vdo.py b/code/xtuner/model/architecture/linear_vdo.py
deleted file mode 100644
index d8a610d09578d0f26bbd3924e7b4ecf86a49ba68..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/linear_vdo.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import torch
-from torch import nn
-from torch.nn import Parameter
-import torch.nn.functional as F
-from functools import reduce
-import operator
-
-eps = 1e-8
-
-class LinearVDO(nn.Module):
- """
- Dense layer implementation with weights ARD-prior (arxiv:1701.05369)
- """
-
- def __init__(self, in_features, out_features, bias=True, thresh=3, ard_init=-8.):
- super(LinearVDO, self).__init__()
- self.in_features = in_features
- self.out_features = out_features
- self.weight = Parameter(torch.Tensor(out_features, in_features))
- self.thresh = thresh
- if bias:
- self.bias = Parameter(torch.Tensor(out_features))
- else:
- self.register_parameter('bias', None)
- self.ard_init = ard_init
- self.log_alp = Parameter(torch.Tensor(out_features, in_features), requires_grad=True)
-
- self.reset_parameters()
-
- def forward(self, input):
- """
- Forward with all regularized connections and random activations (Beyesian mode). Typically used for train
- """
- # if self.training == False: return F.linear(input, self.weights_clipped, self.bias)
-
- W = self.weight
- mu = input.matmul(W.t())
-
- eps = 1e-8
- log_alp = self.log_alp
-
- in2 = input * input
- exp_ = torch.exp(log_alp)
- w2 = self.weight * self.weight
-
- var = in2.matmul(((exp_ * w2) + eps).t())
-
- si = torch.sqrt(var)
-
- activation = mu + torch.normal(torch.zeros_like(mu), torch.ones_like(mu)) * si
- return activation + self.bias
-
- @property
- def weights_clipped(self):
- clip_mask = self.get_clip_mask()
- return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight)
-
- def reset_parameters(self):
- self.weight.data.normal_(std=0.01)
- if self.bias is not None:
- self.bias.data.uniform_(0, 0)
- self.log_alp.data = self.ard_init * torch.ones_like(self.log_alp)
-
- @staticmethod
- def clip(tensor, to=10.):
- """
- Shrink all tensor's values to range [-to,to]
- """
- return torch.clamp(tensor, -to, to)
-
- @staticmethod
- def clip_alp(tensor, lwrb=20.):
- """
- Shrink all tensor's values to range [-to,to]
- """
- return torch.clamp(tensor, -lwrb, -eps)
-
- def get_clip_mask(self):
- log_alp = self.clip_alp(self.log_alp)
- return torch.ge(log_alp, self.thresh)
-
- def train(self, mode):
- self.training = mode
- super(LinearVDO, self).train(mode)
-
- def get_reg(self, **kwargs):
- """
- Get weights regularization (KL(q(w)||p(w)) approximation)
- """
- # a flexible reparameterization of variance
-
- k1 = 0.6134
- k2 = 0.2026
- k3 = 0.7126
-
- log_alp = self.log_alp
-
- element_wise_kl = -.5 * torch.log(1 + 1. / (torch.exp(log_alp))) \
- + k1 * torch.exp(-(k2 + k3 * log_alp) ** 2)
-
- sum_kl = element_wise_kl.mean(dim=(1,))
-
- return - sum_kl.sum()
- # return -torch.mean(minus_kl)
-
- def extra_repr(self):
- return 'in_features={}, out_features={}, bias={}'.format(
- self.in_features, self.out_features, self.bias is not None
- )
-
- def get_dropped_params_cnt(self):
- """
- Get number of dropped weights (with log alpha greater than "thresh" parameter)
- :returns (number of dropped weights, number of all weight)
- """
- return self.get_clip_mask().sum().cpu().numpy()
-
- @property
- def log_alpha(self):
- eps = 1e-8
- return self.log_sigma2 - 2 * torch.log(torch.abs(self.weight) + eps)
-
-
-class Conv2dVDO(nn.Conv2d):
- def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1, ard_init=-1., thresh=3, weight_prob_fwd=True):
- bias = False # Goes to nan if bias = True
- super(Conv2dVDO, self).__init__(in_channels, out_channels, kernel_size, stride,
- padding, dilation, groups, bias)
- self.bias = None
- self.thresh = thresh
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.ard_init = ard_init
- # self.log_sigma2 = Parameter(ard_init * torch.ones_like(self.weight))
- self.log_alp = Parameter(ard_init * torch.ones_like(self.weight), requires_grad=True)
- self.weight_prob_fwd = weight_prob_fwd
-
- # self.log_sigma2 = Parameter(2 * torch.log(torch.abs(self.weight) + eps).clone().detach()+ard_init*torch.ones_like(self.weight))
-
- @staticmethod
- def clip(tensor, to=8):
- """
- Shrink all tensor's values to range [-to,to]
- """
- return torch.clamp(tensor, -to, to)
-
- @staticmethod
- def clip_alp(tensor, lwrb=10.):
- """
- Shrink all tensor's values to range [-to,to]
- """
- return torch.clamp(tensor, -lwrb, -eps)
-
- def set_weight_prob_fwd(self, weight_prob_fwd):
- assert type(weight_prob_fwd) is bool
- self.weight_prob_fwd = weight_prob_fwd
-
- def forward(self, input):
- """
- Forward with all regularized connections and random activations (Beyesian mode). Typically used for train
- """
- if self.training == False and self.weight_prob_fwd == False:
- return F.conv2d(input, self.weights_clipped,
- self.bias, self.stride,
- self.padding, self.dilation, self.groups)
-
- eps = 1e-8
- W = self.weight
- zeros = torch.zeros_like(W)
- clip_mask = self.get_clip_mask()
- conved_mu = F.conv2d(input, W, self.bias, self.stride,
- self.padding, self.dilation, self.groups)
-
- log_alpha = self.log_alp
- # log_alpha = self.log_alpha
-
- conved_si = torch.sqrt(eps + F.conv2d(input * input,
- torch.exp(log_alpha) * W * W, self.bias, self.stride,
- self.padding, self.dilation, self.groups))
-
- conved = conved_mu + \
- conved_si * torch.normal(torch.zeros_like(conved_mu), torch.ones_like(conved_mu))
- return conved
-
- @property
- def weights_clipped(self):
- clip_mask = self.get_clip_mask()
- return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight)
-
- def get_clip_mask(self):
- log_alp = self.clip_alp(self.log_alp)
- # log_alp = self.clip_alp(self.log_alpha)
-
- return torch.ge(log_alp, self.thresh)
-
- def train(self, mode):
- self.training = mode
- super(Conv2dVDO, self).train(mode)
-
- def get_reg(self, **kwargs):
- """
- Get weights regularization (KL(q(w)||p(w)) approximation)
- """
-
- # param 1
- # k1 = 0.792
- # k2 = -0.4826
- # k3 = 0.3451
-
- # param 2
- k1 = 0.6134
- k2 = 0.2026
- k3 = 0.7126
-
- log_alp = self.log_alp
-
- element_wise_kl = -.5 * torch.log(1 + 1./(torch.exp(log_alp))) \
- + k1 * torch.exp(-(k2 + k3 * log_alp) ** 2)
-
- sum_kl = element_wise_kl.mean(dim=(1, 2, 3))
- return - sum_kl.sum()
-
- # log_alp = self.clip_alp(self.log_alp)
- # # log_alp = self.clip_alp(self.log_alpha)
- # # mdkl = k1 * torch.sigmoid(k2 + k3 * log_alp2) - 0.5 * torch.log1p(torch.exp(-log_alp2)) + C
- # minus_kl = .5 * log_alp \
- # + 1.16145124 * torch.exp(log_alp) \
- # - 1.50204118 * torch.exp(log_alp)**2 \
- # + 0.58629921 * torch.exp(log_alp)**3
- #
- # return -torch.sum(minus_kl)
-
- def extra_repr(self):
- return 'in_features={}, out_features={}, bias={}'.format(
- self.in_channels, self.out_channels, self.bias is not None
- )
-
- def get_dropped_params_cnt(self):
- """
- Get number of dropped weights (greater than "thresh" parameter)
- :returns (number of dropped weights, number of all weight)
- """
- return self.get_clip_mask().sum().cpu().numpy()
-
- @property
- def log_alpha(self):
- eps = 1e-8
- return self.log_sigma2 - 2 * torch.log(torch.abs(self.weight) + eps)
diff --git a/code/xtuner/model/architecture/mean_max.py b/code/xtuner/model/architecture/mean_max.py
deleted file mode 100644
index 7d9af509036583db934f42ffc003a6d51cbae3ec..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/mean_max.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch.nn as nn
-
-def initialize_weights(module):
- for m in module.modules():
- if isinstance(m,nn.Linear):
- # ref from clam
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m,nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
-class MeanMIL(nn.Module):
- def __init__(self,n_classes=1,dropout=True,act='relu',test=False):
- super(MeanMIL, self).__init__()
-
- head = [nn.Linear(192,192)]
-
- if act.lower() == 'relu':
- head += [nn.ReLU()]
- elif act.lower() == 'gelu':
- head += [nn.GELU()]
-
- if dropout:
- head += [nn.Dropout(0.25)]
-
- head += [nn.Linear(192,n_classes)]
-
- self.head = nn.Sequential(*head)
-
- self.apply(initialize_weights)
-
- def forward(self,x):
-
- x = self.head(x).mean(axis=1)
- return x
-
-class MaxMIL(nn.Module):
- def __init__(self,n_classes=1,dropout=True,act='relu',test=False):
- super(MaxMIL, self).__init__()
-
- head = [nn.Linear(1024,512)]
-
- if act.lower() == 'relu':
- head += [nn.ReLU()]
- elif act.lower() == 'gelu':
- head += [nn.GELU()]
-
- if dropout:
- head += [nn.Dropout(0.25)]
-
- head += [nn.Linear(512,n_classes)]
- self.head = nn.Sequential(*head)
-
- self.apply(initialize_weights)
-
- def forward(self,x):
- x,_ = self.head(x).max(axis=1)
- return x
diff --git a/code/xtuner/model/architecture/mhim.py b/code/xtuner/model/architecture/mhim.py
deleted file mode 100644
index 81c5bc78438a23189fcb89e514480dc303e9c45f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/mhim.py
+++ /dev/null
@@ -1,300 +0,0 @@
-import torch
-import numpy as np
-from torch import nn
-
-import torch.nn.functional as F
-
-
-class DAttention(nn.Module):
- def __init__(self, n_classes, dropout, act):
- super(DAttention, self).__init__()
- self.L = 512 # 512
- self.D = 128 # 128
- self.K = 1
- self.feature = [nn.Linear(1024, 512)]
-
- if act.lower() == 'gelu':
- self.feature += [nn.GELU()]
- else:
- self.feature += [nn.ReLU()]
-
- if dropout:
- self.feature += [nn.Dropout(0.25)]
-
- self.feature = nn.Sequential(*self.feature)
-
- self.attention = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Tanh(),
- nn.Linear(self.D, self.K)
- )
- self.classifier = nn.Sequential(
- nn.Linear(self.L * self.K, n_classes),
- )
-
- self.apply(initialize_weights)
-
- def forward(self, x, return_attn=False, no_norm=False):
- feature = self.feature(x)
-
- # feature = group_shuffle(feature)
- feature = feature.squeeze(0)
- A = self.attention(feature)
- A_ori = A.clone()
- A = torch.transpose(A, -1, -2) # KxN
- A = F.softmax(A, dim=-1) # softmax over N
- M = torch.mm(A, feature) # KxL
- Y_prob = self.classifier(M)
-
- if return_attn:
- if no_norm:
- return Y_prob, A_ori
- else:
- return Y_prob, A
- else:
- return Y_prob
-
-def initialize_weights(module):
- for m in module.modules():
- if isinstance(m,nn.Linear):
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m,nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
-class SoftTargetCrossEntropy_v2(nn.Module):
-
- def __init__(self,temp_t=1.,temp_s=1.):
- super(SoftTargetCrossEntropy_v2, self).__init__()
- self.temp_t = temp_t
- self.temp_s = temp_s
-
- def forward(self, x: torch.Tensor, target: torch.Tensor, mean: bool= True) -> torch.Tensor:
- loss = torch.sum(-F.softmax(target / self.temp_t,dim=-1) * F.log_softmax(x / self.temp_s, dim=-1), dim=-1)
- if mean:
- return loss.mean()
- else:
- return loss
-
-class MHIM(nn.Module):
- def __init__(self, mlp_dim=512,mask_ratio=0,n_classes=2,temp_t=1.,temp_s=1.,dropout=0.25,act='relu',select_mask=True,select_inv=False,msa_fusion='vote',mask_ratio_h=0.,mrh_sche=None,mask_ratio_hr=0.,mask_ratio_l=0.,da_act='gelu',baseline='selfattn',head=8,attn_layer=0):
- super(MHIM, self).__init__()
-
- self.mask_ratio = mask_ratio
- self.mask_ratio_h = mask_ratio_h
- self.mask_ratio_hr = mask_ratio_hr
- self.mask_ratio_l = mask_ratio_l
- self.select_mask = select_mask
- self.select_inv = select_inv
- self.msa_fusion = msa_fusion
- self.mrh_sche = mrh_sche
- self.attn_layer = attn_layer
-
- self.patch_to_emb = [nn.Linear(1024, 512)]
-
- if act.lower() == 'relu':
- self.patch_to_emb += [nn.ReLU()]
- elif act.lower() == 'gelu':
- self.patch_to_emb += [nn.GELU()]
-
- self.dp = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
-
- self.patch_to_emb = nn.Sequential(*self.patch_to_emb)
-
- self.online_encoder = DAttention(mlp_dim,da_act)
-
- self.predictor = nn.Linear(mlp_dim,n_classes)
-
- self.temp_t = temp_t
- self.temp_s = temp_s
-
- self.cl_loss = SoftTargetCrossEntropy_v2(self.temp_t,self.temp_s)
-
- self.predictor_cl = nn.Identity()
- self.target_predictor = nn.Identity()
-
- self.apply(initialize_weights)
-
- def select_mask_fn(self,ps,attn,largest,mask_ratio,mask_ids_other=None,len_keep_other=None,cls_attn_topk_idx_other=None,random_ratio=1.,select_inv=False):
- ps_tmp = ps
- mask_ratio_ori = mask_ratio
- mask_ratio = mask_ratio / random_ratio
- if mask_ratio > 1:
- random_ratio = mask_ratio_ori
- mask_ratio = 1.
-
- # print(attn.size())
- if mask_ids_other is not None:
- if cls_attn_topk_idx_other is None:
- cls_attn_topk_idx_other = mask_ids_other[:,len_keep_other:].squeeze()
- ps_tmp = ps - cls_attn_topk_idx_other.size(0)
- if len(attn.size()) > 2:
- if self.msa_fusion == 'mean':
- _,cls_attn_topk_idx = torch.topk(attn,int(np.ceil((ps_tmp*mask_ratio)) // attn.size(1)),largest=largest)
- cls_attn_topk_idx = torch.unique(cls_attn_topk_idx.flatten(-3,-1))
- elif self.msa_fusion == 'vote':
- vote = attn.clone()
- vote[:] = 0
-
- _,idx = torch.topk(attn,k=int(np.ceil((ps_tmp*mask_ratio))),sorted=False,largest=largest)
- mask = vote.clone()
- mask = mask.scatter_(2,idx,1) == 1
- vote[mask] = 1
- vote = vote.sum(dim=1)
- _,cls_attn_topk_idx = torch.topk(vote,k=int(np.ceil((ps_tmp*mask_ratio))),sorted=False)
- # print(cls_attn_topk_idx.size())
- cls_attn_topk_idx = cls_attn_topk_idx[0]
- else:
- k = int(np.ceil((ps_tmp*mask_ratio)))
- _,cls_attn_topk_idx = torch.topk(attn,k,largest=largest)
- cls_attn_topk_idx = cls_attn_topk_idx.squeeze(0)
-
- # randomly
- if random_ratio < 1.:
- random_idx = torch.randperm(cls_attn_topk_idx.size(0),device=cls_attn_topk_idx.device)
-
- cls_attn_topk_idx = torch.gather(cls_attn_topk_idx,dim=0,index=random_idx[:int(np.ceil((cls_attn_topk_idx.size(0)*random_ratio)))])
-
-
- # concat other masking idx
- if mask_ids_other is not None:
- cls_attn_topk_idx = torch.cat([cls_attn_topk_idx,cls_attn_topk_idx_other]).unique()
-
- # if cls_attn_topk_idx is not None:
- len_keep = ps - cls_attn_topk_idx.size(0)
- a = set(cls_attn_topk_idx.tolist())
- b = set(list(range(ps)))
- mask_ids = torch.tensor(list(b.difference(a)),device=attn.device)
- if select_inv:
- mask_ids = torch.cat([cls_attn_topk_idx,mask_ids]).unsqueeze(0)
- len_keep = ps - len_keep
- else:
- mask_ids = torch.cat([mask_ids,cls_attn_topk_idx]).unsqueeze(0)
-
- return len_keep,mask_ids
-
- def get_mask(self,ps,i,attn,mrh=None):
- if attn is not None and isinstance(attn,(list,tuple)):
- if self.attn_layer == -1:
- attn = attn[1]
- else:
- attn = attn[self.attn_layer]
- else:
- attn = attn
-
- # random mask
- if attn is not None and self.mask_ratio > 0.:
- len_keep,mask_ids = self.select_mask_fn(ps,attn,False,self.mask_ratio,select_inv=self.select_inv,random_ratio=0.001)
- else:
- len_keep,mask_ids = ps,None
-
- # low attention mask
- if attn is not None and self.mask_ratio_l > 0.:
- if mask_ids is None:
- len_keep,mask_ids = self.select_mask_fn(ps,attn,False,self.mask_ratio_l,select_inv=self.select_inv)
- else:
- cls_attn_topk_idx_other = mask_ids[:,:len_keep].squeeze() if self.select_inv else mask_ids[:,len_keep:].squeeze()
- len_keep,mask_ids = self.select_mask_fn(ps,attn,False,self.mask_ratio_l,select_inv=self.select_inv,mask_ids_other=mask_ids,len_keep_other=ps,cls_attn_topk_idx_other = cls_attn_topk_idx_other)
-
- # high attention mask
- mask_ratio_h = self.mask_ratio_h
- if self.mrh_sche is not None:
- mask_ratio_h = self.mrh_sche[i]
- if mrh is not None:
- mask_ratio_h = mrh
- if mask_ratio_h > 0. :
- # mask high conf patch
- if mask_ids is None:
- len_keep,mask_ids = self.select_mask_fn(ps,attn,largest=True,mask_ratio=mask_ratio_h,len_keep_other=ps,random_ratio=self.mask_ratio_hr,select_inv=self.select_inv)
- else:
- cls_attn_topk_idx_other = mask_ids[:,:len_keep].squeeze() if self.select_inv else mask_ids[:,len_keep:].squeeze()
-
- len_keep,mask_ids = self.select_mask_fn(ps,attn,largest=True,mask_ratio=mask_ratio_h,mask_ids_other=mask_ids,len_keep_other=ps,cls_attn_topk_idx_other = cls_attn_topk_idx_other,random_ratio=self.mask_ratio_hr,select_inv=self.select_inv)
-
- return len_keep,mask_ids
-
- @torch.no_grad()
- def forward_teacher(self,x,return_attn=False):
-
- x = self.patch_to_emb(x)
- x = self.dp(x)
-
- if return_attn:
- x,attn = self.online_encoder(x,return_attn=True)
- else:
- x = self.online_encoder(x)
- attn = None
-
- return x,attn
-
- @torch.no_grad()
- def forward_test(self,x,return_attn=False,no_norm=False):
- x = self.patch_to_emb(x)
- x = self.dp(x)
-
- if return_attn:
- x,a = self.online_encoder(x,return_attn=True,no_norm=no_norm)
- else:
- x = self.online_encoder(x)
- x = self.predictor(x)
-
- if return_attn:
- return x,a
- else:
- return x
-
- def pure(self,x,return_attn=False):
- x = self.patch_to_emb(x)
- x = self.dp(x)
- ps = x.size(1)
-
- if return_attn:
- x,attn = self.online_encoder(x,return_attn=True)
- else:
- x = self.online_encoder(x)
-
- x = self.predictor(x)
-
- if self.training:
- if return_attn:
- return x, 0, ps,ps,attn
- else:
- return x, 0, ps,ps
- else:
- if return_attn:
- return x,attn
- else:
- return x
-
- def forward_loss(self, student_cls_feat, teacher_cls_feat):
- if teacher_cls_feat is not None:
- cls_loss = self.cl_loss(student_cls_feat,teacher_cls_feat.detach())
- else:
- cls_loss = 0.
-
- return cls_loss
-
- def forward(self, x,attn=None,teacher_cls_feat=None,i=None):
- x = self.patch_to_emb(x)
- x = self.dp(x)
-
- ps = x.size(1)
-
- # get mask
- if self.select_mask:
- len_keep,mask_ids = self.get_mask(ps,i,attn)
- else:
- len_keep,mask_ids = ps,None
-
- # forward online network
- student_cls_feat= self.online_encoder(x,len_keep=len_keep,mask_ids=mask_ids,mask_enable=True)
-
- # prediction
- student_logit = self.predictor(student_cls_feat)
-
- # cl loss
- cls_loss= self.forward_loss(student_cls_feat=student_cls_feat,teacher_cls_feat=teacher_cls_feat)
-
- return student_logit, cls_loss,ps,len_keep
diff --git a/code/xtuner/model/architecture/network.py b/code/xtuner/model/architecture/network.py
deleted file mode 100644
index fc158a1dd6537d4fa861308f7f5558ffa46f4544..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/network.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-
-
-class Classifier_1fc(nn.Module):
- def __init__(self, n_channels, n_classes, droprate=0.0):
- super(Classifier_1fc, self).__init__()
- self.fc = nn.Linear(n_channels, n_classes)
- self.droprate = droprate
- if self.droprate != 0.0:
- self.dropout = torch.nn.Dropout(p=self.droprate)
-
- def forward(self, x):
-
- if self.droprate != 0.0:
- x = self.dropout(x)
- x = self.fc(x)
- return x
-
-
-class residual_block(nn.Module):
- def __init__(self, nChn=512):
- super(residual_block, self).__init__()
- self.block = nn.Sequential(
- nn.Linear(nChn, nChn, bias=False),
- nn.ReLU(inplace=True),
- nn.Linear(nChn, nChn, bias=False),
- nn.ReLU(inplace=True),
- )
- def forward(self, x):
- tt = self.block(x)
- x = x + tt
- return x
-
-
-class DimReduction(nn.Module):
- def __init__(self, n_channels, m_dim=512, numLayer_Res=0):
- super(DimReduction, self).__init__()
- self.fc1 = nn.Linear(n_channels, m_dim, bias=False)
- self.relu1 = nn.ReLU(inplace=True)
- self.numRes = numLayer_Res
-
- self.resBlocks = []
- for ii in range(numLayer_Res):
- self.resBlocks.append(residual_block(m_dim))
- self.resBlocks = nn.Sequential(*self.resBlocks)
-
- def forward(self, x):
-
- x = self.fc1(x)
- x = self.relu1(x)
-
- if self.numRes > 0:
- x = self.resBlocks(x)
-
- return x
-
-
-
-class DimReduction1(nn.Module):
- def __init__(self, n_channels, m_dim=512, numLayer_Res=0):
- super(DimReduction1, self).__init__()
- self.fc1 = nn.Linear(n_channels, m_dim)
- self.relu1 = nn.ReLU(inplace=True)
- self.numRes = numLayer_Res
-
- self.resBlocks = []
- for ii in range(numLayer_Res):
- self.resBlocks.append(residual_block(m_dim))
- self.resBlocks = nn.Sequential(*self.resBlocks)
-
- def forward(self, x):
- x_ = x
- x = self.fc1(x)
- x = self.relu1(x+x_)
-
- if self.numRes > 0:
- x = self.resBlocks(x)
-
- return x
-
-
-
diff --git a/code/xtuner/model/architecture/nystrom_attention.py b/code/xtuner/model/architecture/nystrom_attention.py
deleted file mode 100644
index ac46bbe45aa259dd6d432338016567c39763bf2f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/nystrom_attention.py
+++ /dev/null
@@ -1,204 +0,0 @@
-from math import ceil
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from einops import rearrange, reduce
-
-# helper functions
-
-def exists(val):
- return val is not None
-
-def moore_penrose_iter_pinv(x, iters = 6):
- device = x.device
-
- abs_x = torch.abs(x)
- col = abs_x.sum(dim = -1)
- row = abs_x.sum(dim = -2)
- z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))
-
- I = torch.eye(x.shape[-1], device = device)
- I = rearrange(I, 'i j -> () i j')
-
- for _ in range(iters):
- xz = x @ z
- z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))
-
- return z
-
-# main attention class
-class NystromAttention(nn.Module):
- def __init__(
- self,
- dim,
- dim_head = 64,
- heads = 8,
- num_landmarks = 256,
- pinv_iterations = 6,
- residual = True,
- residual_conv_kernel = 33,
- eps = 1e-8,
- dropout = 0.,
- n_token = 1
- ):
- super().__init__()
- self.eps = eps
- inner_dim = heads * dim_head
- self.n_token = n_token
-
- self.num_landmarks = num_landmarks
- self.pinv_iterations = pinv_iterations
-
- self.heads = heads
- self.scale = dim_head ** -0.5
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, dim),
- nn.Dropout(dropout)
- )
-
- self.residual = residual
- if residual:
- kernel_size = residual_conv_kernel
- padding = residual_conv_kernel // 2
- self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)
-
- def forward(self, x, mask = None, return_attn = False):
- b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps
-
- # pad so that sequence can be evenly divided into m landmarks
-
- remainder = n % m
- if remainder > 0:
- padding = m - (n % m)
- x = F.pad(x, (0, 0, padding, 0), value = 0)
-
- if exists(mask):
- mask = F.pad(mask, (padding, 0), value = False)
-
- # derive query, keys, values
-
- q, k, v = self.to_qkv(x).chunk(3, dim = -1)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
-
- # set masked positions to 0 in queries, keys, values
-
- if exists(mask):
- mask = rearrange(mask, 'b n -> b () n')
- q, k, v = map(lambda t: t * mask[..., None], (q, k, v))
-
- q = q * self.scale
-
- # generate landmarks by sum reduction, and then calculate mean using the mask
-
- l = ceil(n / m)
- landmark_einops_eq = '... (n l) d -> ... n d'
- q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
- k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)
-
- # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean
-
- divisor = l
- if exists(mask):
- mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
- divisor = mask_landmarks_sum[..., None] + eps
- mask_landmarks = mask_landmarks_sum > 0
-
- # masked mean (if mask exists)
-
- q_landmarks /= divisor
- k_landmarks /= divisor
-
- # similarities
-
- einops_eq = '... i d, ... j d -> ... i j'
- attn1 = einsum(einops_eq, q, k_landmarks)
- attn2 = einsum(einops_eq, q_landmarks, k_landmarks)
- attn3 = einsum(einops_eq, q_landmarks, k)
-
- # masking
-
- if exists(mask):
- mask_value = -torch.finfo(q.dtype).max
- sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
- sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
- sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)
-
- # eq (15) in the paper and aggregate values
-
- attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (attn1, attn2, attn3))
- attn2 = moore_penrose_iter_pinv(attn2, iters)
- out = (attn1 @ attn2) @ (attn3 @ v)
-
- # add depth-wise conv residual of values
- if self.residual:
- out += self.res_conv(v)
-
- # merge and combine heads
-
- out = rearrange(out, 'b h n d -> b n (h d)', h = h)
- out = self.to_out(out)
- out = out[:, -n:]
- if return_attn:
- attn1 = attn1[:,:,:self.n_token] @ attn2
- attn1 = (attn1 @ attn3)
-
- return out, attn1.mean(1)
-
- return out
-
-# transformer
-
-class PreNorm(nn.Module):
- def __init__(self, dim, fn):
- super().__init__()
- self.norm = nn.LayerNorm(dim)
- self.fn = fn
-
- def forward(self, x, **kwargs):
- x = self.norm(x)
- return self.fn(x, **kwargs)
-
-class FeedForward(nn.Module):
- def __init__(self, dim, mult = 4, dropout = 0.):
- super().__init__()
- self.net = nn.Sequential(
- nn.Linear(dim, dim * mult),
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Linear(dim * mult, dim)
- )
-
- def forward(self, x):
- return self.net(x)
-
-class Nystromformer(nn.Module):
- def __init__(
- self,
- *,
- dim,
- depth,
- dim_head = 64,
- heads = 8,
- num_landmarks = 256,
- pinv_iterations = 6,
- attn_values_residual = True,
- attn_values_residual_conv_kernel = 33,
- attn_dropout = 0.,
- ff_dropout = 0.
- ):
- super().__init__()
-
- self.layers = nn.ModuleList([])
- for _ in range(depth):
- self.layers.append(nn.ModuleList([
- PreNorm(dim, NystromAttention(dim = dim, dim_head = dim_head, heads = heads, num_landmarks = num_landmarks, pinv_iterations = pinv_iterations, residual = attn_values_residual, residual_conv_kernel = attn_values_residual_conv_kernel, dropout = attn_dropout)),
- PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout))
- ]))
-
- def forward(self, x, mask = None):
- for attn, ff in self.layers:
- x = attn(x, mask = mask) + x
- x = ff(x) + x
- return x
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/transMIL.py b/code/xtuner/model/architecture/transMIL.py
deleted file mode 100644
index 884914e472d774449cc95d11a3fdad60e31fc82f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/transMIL.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-from nystrom_attention import NystromAttention
-
-
-class TransLayer(nn.Module):
-
- def __init__(self, norm_layer=nn.LayerNorm, dim=512):
- super().__init__()
- self.norm = norm_layer(dim)
- self.attn = NystromAttention(
- dim=dim,
- dim_head=dim // 8,
- heads=8,
- num_landmarks=dim // 2, # number of landmarks
- pinv_iterations=6,
- # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
- residual=True,
- # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
- dropout=0.1
- )
-
- def forward(self, x):
- x = x + self.attn(self.norm(x))
-
- return x
-
-
-class PPEG(nn.Module):
- def __init__(self, dim=512):
- super(PPEG, self).__init__()
- self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
- self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
- self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)
-
- def forward(self, x, H, W):
- B, _, C = x.shape
- cls_token, feat_token = x[:, 0], x[:, 1:]
- cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
- x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
- x = x.flatten(2).transpose(1, 2)
- x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
- return x
-
-
-class TransMIL(nn.Module):
- def __init__(self, conf):
- super(TransMIL, self).__init__()
- self.pos_layer = PPEG(dim=conf.D_inner)
- self._fc1 = nn.Sequential(nn.Linear(conf.D_feat, conf.D_inner), nn.ReLU())
- self.cls_token = nn.Parameter(torch.randn(1, 1, conf.D_inner))
- self.n_classes = conf.n_class
- self.layer1 = TransLayer(dim=conf.D_inner)
- self.layer2 = TransLayer(dim=conf.D_inner)
- self.norm = nn.LayerNorm(conf.D_inner)
- self._fc2 = nn.Linear(conf.D_inner, conf.n_class)
-
- def forward(self, input):
- h = self._fc1(input) # [B, n, 512]
-
- # ---->pad
- H = h.shape[1]
- _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
- add_length = _H * _W - H
- h = torch.cat([h, h[:, :add_length, :]], dim=1) # [B, N, 512]
-
- # ---->cls_token
- B = h.shape[0]
- cls_tokens = self.cls_token.expand(B, -1, -1).cuda()
- h = torch.cat((cls_tokens, h), dim=1)
-
- # ---->Translayer x1
- h = self.layer1(h) # [B, N, 512]
-
- # ---->PPEG
- h = self.pos_layer(h, _H, _W) # [B, N, 512]
-
- # ---->Translayer x2
- h = self.layer2(h) # [B, N, 512]
-
- # ---->cls_token
- h = self.norm(h)[:, 0]
-
- # ---->predict
- logits = self._fc2(h) # [B, n_classes]
- # Y_hat = torch.argmax(logits, dim=1)
- # Y_prob = F.softmax(logits, dim=1)
- # results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
- return logits
-
-
-if __name__ == "__main__":
- data = torch.randn((1, 6000, 1024)).cuda()
- model = TransMIL(n_classes=2).cuda()
- print(model.eval())
- results_dict = model(data=data)
- print(results_dict)
\ No newline at end of file
diff --git a/code/xtuner/model/architecture/transformer.py b/code/xtuner/model/architecture/transformer.py
deleted file mode 100644
index c69e08a0db818c29ce2651e3ca77b41d320d371f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/architecture/transformer.py
+++ /dev/null
@@ -1,420 +0,0 @@
-import math
-import os
-
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-from architecture.network import Classifier_1fc, DimReduction, DimReduction1
-from einops import repeat
-from .nystrom_attention import NystromAttention
-from .emb_position import *
-
-def pos_enc_1d(D, len_seq):
-
- if D % 2 != 0:
- raise ValueError("Cannot use sin/cos positional encoding with "
- "odd dim (got dim={:d})".format(D))
- pe = torch.zeros(len_seq, D)
- position = torch.arange(0, len_seq).unsqueeze(1)
- div_term = torch.exp((torch.arange(0, D, 2, dtype=torch.float) *
- -(math.log(10000.0) / D)))
- pe[:, 0::2] = torch.sin(position.float() * div_term)
- pe[:, 1::2] = torch.cos(position.float() * div_term)
-
- return pe
-
-
-class MLP(nn.Module):
- def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate):
- super(MLP, self).__init__()
- self.fc1 = nn.Linear(input_dim, hidden_dim)
- self.fc2 = nn.Linear(hidden_dim, output_dim)
- self.dropout = nn.Dropout(dropout_rate)
-
- def forward(self, x):
- x = self.fc1(x)
- x = torch.relu(x)
- x = self.dropout(x)
- x = self.fc2(x)
- return x
-
-class MLP_single_layer(nn.Module):
- def __init__(self, input_dim, output_dim):
- super(MLP_single_layer, self).__init__()
- self.fc = nn.Linear(input_dim, output_dim)
-
- def forward(self, x):
- x = self.fc(x)
- return x
-
-class ACMIL_MHA(nn.Module):
- def __init__(self, conf, n_token=1, n_masked_patch=0, mask_drop=0):
- super(ACMIL_MHA, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.sub_attention = nn.ModuleList()
- for i in range(n_token):
- self.sub_attention.append(MutiHeadAttention(conf.D_inner, 8, n_masked_patch=n_masked_patch, mask_drop=mask_drop))
- self.bag_attention = MutiHeadAttention_modify(conf.D_inner, 8)
- self.q = nn.Parameter(torch.zeros((1, n_token, conf.D_inner)))
- nn.init.normal_(self.q, std=1e-6)
- self.n_class = conf.n_class
-
- self.classifier = nn.ModuleList()
- for i in range(n_token):
- self.classifier.append(Classifier_1fc(conf.D_inner, conf.n_class, 0.0))
- self.n_token = n_token
- self.Slide_classifier = Classifier_1fc(conf.D_inner, conf.n_class, 0.0)
-
- def forward(self, input):
- input = self.dimreduction(input)
- q = self.q
- k = input
- v = input
- outputs = []
- attns = []
- for i in range(self.n_token):
- feat_i, attn_i = self.sub_attention[i](q[:, i].unsqueeze(0), k, v)
- outputs.append(self.classifier[i](feat_i))
- attns.append(attn_i)
-
- attns = torch.cat(attns, 1)
- feat_bag = self.bag_attention(v, attns.softmax(dim=-1).mean(1, keepdim=True))
-
- return torch.cat(outputs, dim=0), self.Slide_classifier(feat_bag), attns
-
-
-class MHA(nn.Module):
- def __init__(self, conf):
- super(MHA, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.attention = MutiHeadAttention(conf.D_inner, 8)
- self.q = nn.Parameter(torch.zeros((1, 1, conf.D_inner)))
- nn.init.normal_(self.q, std=1e-6)
- self.n_class = conf.n_class
- self.classifier = Classifier_1fc(conf.D_inner, conf.n_class, 0.0)
-
- def forward(self, input):
- input = self.dimreduction(input)
- q = self.q
- k = input
- v = input
- feat, attn = self.attention(q, k, v)
- output = self.classifier(feat)
-
- return output
-
-
-class MutiHeadAttention(nn.Module):
- """
- An attention layer that allows for downscaling the size of the embedding
- after projection to queries, keys, and values.
- """
-
- def __init__(
- self,
- embedding_dim: int,
- num_heads: int,
- downsample_rate: int = 1,
- dropout: float = 0.1,
- n_masked_patch: int = 0,
- mask_drop: float = 0.0
- ) -> None:
- super().__init__()
- self.n_masked_patch = n_masked_patch
- self.mask_drop = mask_drop
- self.embedding_dim = embedding_dim
- self.internal_dim = embedding_dim // downsample_rate
- self.num_heads = num_heads
- assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-
- self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
-
- self.layer_norm = nn.LayerNorm(embedding_dim, eps=1e-6)
- self.dropout = nn.Dropout(dropout)
-
- def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
- b, n, c = x.shape
- x = x.reshape(b, n, num_heads, c // num_heads)
- return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
-
- def _recombine_heads(self, x: Tensor) -> Tensor:
- b, n_heads, n_tokens, c_per_head = x.shape
- x = x.transpose(1, 2)
- return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
-
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
- # Input projections
- q = self.q_proj(q)
- k = self.k_proj(k)
- v = self.v_proj(v)
-
- # Separate into heads
- q = self._separate_heads(q, self.num_heads)
- k = self._separate_heads(k, self.num_heads)
- v = self._separate_heads(v, self.num_heads)
-
- # Attention
- _, _, _, c_per_head = q.shape
- attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
- attn = attn / math.sqrt(c_per_head)
-
- if self.n_masked_patch > 0 and self.training:
- # Get the indices of the top-k largest values
- b, h, q, c = attn.shape
- n_masked_patch = min(self.n_masked_patch, c)
- _, indices = torch.topk(attn, n_masked_patch, dim=-1)
- indices = indices.reshape(b * h * q, -1)
- rand_selected = torch.argsort(torch.rand(*indices.shape), dim=-1)[:,:int(n_masked_patch * self.mask_drop)]
- masked_indices = indices[torch.arange(indices.shape[0]).unsqueeze(-1), rand_selected]
- random_mask = torch.ones(b*h*q, c).to(attn.device)
- random_mask.scatter_(-1, masked_indices, 0)
- attn = attn.masked_fill(random_mask.reshape(b, h, q, -1) == 0, -1e9)
-
- attn_out = attn
- attn = torch.softmax(attn, dim=-1)
- # Get output
- out1 = attn @ v
- out1 = self._recombine_heads(out1)
- out1 = self.out_proj(out1)
- out1 = self.dropout(out1)
- out1 = self.layer_norm(out1)
-
- return out1[0], attn_out[0]
-
-class MutiHeadAttention_modify(nn.Module):
- """
- An attention layer that allows for downscaling the size of the embedding
- after projection to queries, keys, and values.
- """
-
- def __init__(
- self,
- embedding_dim: int,
- num_heads: int,
- downsample_rate: int = 1,
- dropout: float = 0.1,
- ) -> None:
- super().__init__()
- self.embedding_dim = embedding_dim
- self.internal_dim = embedding_dim // downsample_rate
- self.num_heads = num_heads
- assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-
- self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
-
- self.layer_norm = nn.LayerNorm(embedding_dim, eps=1e-6)
- self.dropout = nn.Dropout(dropout)
-
- def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
- b, n, c = x.shape
- x = x.reshape(b, n, num_heads, c // num_heads)
- return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
-
- def _recombine_heads(self, x: Tensor) -> Tensor:
- b, n_heads, n_tokens, c_per_head = x.shape
- x = x.transpose(1, 2)
- return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
-
- def forward(self, v: Tensor, attn: Tensor) -> Tensor:
- # Input projections
- v = self.v_proj(v)
-
- # Separate into heads
- v = self._separate_heads(v, self.num_heads)
-
- # Get output
- out1 = attn @ v
- out1 = self._recombine_heads(out1)
- out1 = self.out_proj(out1)
- out1 = self.dropout(out1)
- out1 = self.layer_norm(out1)
-
- return out1[0]
-
-
-class Attention_Gated(nn.Module):
- def __init__(self, L=512, D=128, K=1):
- super(Attention_Gated, self).__init__()
-
- self.L = L
- self.D = D
- self.K = K
-
- self.attention_V = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Tanh()
- )
-
- self.attention_U = nn.Sequential(
- nn.Linear(self.L, self.D),
- nn.Sigmoid()
- )
-
- self.attention_weights = nn.Linear(self.D, self.K)
-
- def forward(self, x):
- ## x: N x L
- A_V = self.attention_V(x) # NxD
- A_U = self.attention_U(x) # NxD
- A = self.attention_weights(A_V * A_U) # NxK
- A = torch.transpose(A, 1, 0) # KxN
-
-
- return A ### K x N
-
-
-class ABMIL(nn.Module):
- def __init__(self, conf, D=128, droprate=0):
- super(ABMIL, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.attention = Attention_Gated(conf.D_inner, D, 1)
- self.classifier = Classifier_1fc(conf.D_inner, conf.n_class, droprate)
-
- def forward(self, x): ## x: N x L
- x = x[0]
- med_feat = self.dimreduction(x)
- A = self.attention(med_feat) ## K x N
-
- A_out = A
- A = F.softmax(A, dim=1) # softmax over N
- afeat = torch.mm(A, med_feat) ## K x L
- outputs = self.classifier(afeat)
- return outputs
-
-
-
-
-class ACMIL_GA(nn.Module):
- def __init__(self, conf, D=128, droprate=0, n_token=1, n_masked_patch=0, mask_drop=0):
- super(ACMIL_GA, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.attention = Attention_Gated(conf.D_inner, D, n_token)
- self.classifier = nn.ModuleList()
- for i in range(n_token):
- self.classifier.append(Classifier_1fc(conf.D_inner, conf.n_class, droprate))
- self.n_masked_patch = n_masked_patch
- self.n_token = conf.n_token
- self.Slide_classifier = Classifier_1fc(conf.D_inner, conf.n_class, droprate)
- self.mask_drop = mask_drop
-
-
- def forward(self, x): ## x: N x L
- x = x[0]
- x = self.dimreduction(x)
- A = self.attention(x) ## K x N
-
-
- if self.n_masked_patch > 0 and self.training:
- # Get the indices of the top-k largest values
- k, n = A.shape
- n_masked_patch = min(self.n_masked_patch, n)
- _, indices = torch.topk(A, n_masked_patch, dim=-1)
- rand_selected = torch.argsort(torch.rand(*indices.shape), dim=-1)[:,:int(n_masked_patch * self.mask_drop)]
- masked_indices = indices[torch.arange(indices.shape[0]).unsqueeze(-1), rand_selected]
- random_mask = torch.ones(k, n).to(A.device)
- random_mask.scatter_(-1, masked_indices, 0)
- A = A.masked_fill(random_mask == 0, -1e9)
-
- A_out = A
- A = F.softmax(A, dim=1) # softmax over N
- afeat = torch.mm(A, x) ## K x L
- outputs = []
- for i, head in enumerate(self.classifier):
- outputs.append(head(afeat[i]))
- bag_A = F.softmax(A_out, dim=1).mean(0, keepdim=True)
- bag_feat = torch.mm(bag_A, x)
- return torch.stack(outputs, dim=0), self.Slide_classifier(bag_feat), A_out.unsqueeze(0)
-
- def forward_feature(self, x, use_attention_mask=False): ## x: N x L
- x = x[0]
- x = self.dimreduction(x)
- A = self.attention(x) ## K x N
-
-
- if self.n_masked_patch > 0 and use_attention_mask:
- # Get the indices of the top-k largest values
- k, n = A.shape
- n_masked_patch = min(self.n_masked_patch, n)
- _, indices = torch.topk(A, n_masked_patch, dim=-1)
- rand_selected = torch.argsort(torch.rand(*indices.shape), dim=-1)[:,:int(n_masked_patch * self.mask_drop)]
- masked_indices = indices[torch.arange(indices.shape[0]).unsqueeze(-1), rand_selected]
- random_mask = torch.ones(k, n).to(A.device)
- random_mask.scatter_(-1, masked_indices, 0)
- A = A.masked_fill(random_mask == 0, -1e9)
-
- A_out = A
- bag_A = F.softmax(A_out, dim=1).mean(0, keepdim=True)
- bag_feat = torch.mm(bag_A, x)
- return bag_feat
-
-
-
-class ACMIL_MHA_NoClassifier(nn.Module):
- def __init__(self, conf, n_token=1, n_masked_patch=0, mask_drop=0):
- super(ACMIL_MHA_NoClassifier, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.sub_attention = nn.ModuleList([
- MutiHeadAttention(conf.D_inner, 8, n_masked_patch=n_masked_patch, mask_drop=mask_drop)
- for _ in range(n_token)
- ])
- self.bag_attention = MutiHeadAttention_modify(conf.D_inner, 8)
- self.q = nn.Parameter(torch.zeros((1, n_token, conf.D_inner)))
- nn.init.normal_(self.q, std=1e-6)
- self.n_token = n_token
-
- def forward(self, input):
- x = self.dimreduction(input)
- q = self.q
- k = x
- v = x
- feats = []
- attns = []
-
- for i in range(self.n_token):
- feat_i, attn_i = self.sub_attention[i](q[:, i].unsqueeze(0), k, v)
- feats.append(feat_i)
- attns.append(attn_i)
-
- attns_tensor = torch.cat(attns, 1)
- bag_feat = self.bag_attention(v, attns_tensor.softmax(dim=-1).mean(1, keepdim=True))
-
- # Return the raw instance features, bag feature, and attention map
- return torch.cat(feats, dim=0), bag_feat, attns_tensor
-
-class ACMIL_GA_NoClassifier(nn.Module):
- def __init__(self, conf, D=128, droprate=0, n_token=1, n_masked_patch=0, mask_drop=0):
- super(ACMIL_GA_NoClassifier, self).__init__()
- self.dimreduction = DimReduction(conf.D_feat, conf.D_inner)
- self.attention = Attention_Gated(conf.D_inner, D, n_token)
- self.n_token = n_token
- self.n_masked_patch = n_masked_patch
- self.mask_drop = mask_drop
-
- def forward(self, x): # x: list or tuple with x[0] = tensor N x L
- x = x[0]
- x_red = self.dimreduction(x)
- A = self.attention(x_red) # K x N
-
- if self.n_masked_patch > 0 and self.training:
- k, n = A.shape
- n_masked = min(self.n_masked_patch, n)
- _, indices = torch.topk(A, n_masked, dim=-1)
- rand_sel = torch.argsort(torch.rand(*indices.shape), dim=-1)
- rand_sel = rand_sel[:, :int(n_masked * self.mask_drop)]
- masked_inds = indices[torch.arange(k).unsqueeze(-1), rand_sel]
- mask = torch.ones_like(A)
- mask.scatter_(-1, masked_inds, 0)
- A = A.masked_fill(mask == 0, -1e9)
-
- A_out = A
- bag_A = F.softmax(A_out, dim=1).mean(0, keepdim=True)
- bag_feat = torch.mm(bag_A, x_red) # 1 x L
-
- # Return token-level attention, bag feature, and raw attention scores
- return A_out, bag_feat, None
-
-# Remove import of Classifier_1fc and related classifier usages.
\ No newline at end of file
diff --git a/code/xtuner/model/dpo.py b/code/xtuner/model/dpo.py
deleted file mode 100644
index faaa43402cb077ca39d9418e778b5bcbede10ace..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dpo.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 # noqa
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-# Copyright (c) OpenMMLab. All rights reserved.
-from copy import deepcopy
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from mmengine import MessageHub
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.parallel.sequence import (gather_forward_split_backward,
- get_sequence_parallel_group,
- get_sequence_parallel_world_size,
- split_for_sequence_parallel)
-from .sft import SupervisedFinetune
-
-
-def disable_grad(model):
- # freeze parameters
- parameter_names = [n for n, _ in model.named_parameters()]
- for param_name in parameter_names:
- param = model.get_parameter(param_name)
- param.requires_grad = False
- return model.eval()
-
-
-def create_reference_model(model):
- if is_deepspeed_zero3_enabled():
- raise ValueError('DeepSpeed ZeRO-3 is enabled and is not compatible '
- 'with `create_reference_model()`. Please instantiate '
- 'your reference model directly with '
- '`AutoCausalLM.from_pretrained()`.')
- ref_model = deepcopy(model)
- ref_model = disable_grad(ref_model)
- return ref_model
-
-
-class DPO(SupervisedFinetune):
- """A general class of DPO and its variants."""
-
- def __init__(self,
- llm,
- ref_llm=None,
- beta=0.1,
- loss_type='sigmoid',
- label_smoothing=0.0,
- **kwargs):
- super().__init__(llm, **kwargs)
- self.loss_type = loss_type
- self.label_smoothing = label_smoothing
- self.beta = beta
-
- if ref_llm is not None:
- ref_llm = self.build_llm_from_cfg(
- ref_llm, kwargs.get('use_varlen_attn', False),
- kwargs.get('max_position_embeddings', None))
- self.ref_llm = disable_grad(ref_llm)
- else:
- self.ref_llm = None if self.use_lora else create_reference_model(
- self.llm)
-
- def _gather_masked_logits(self, logits, labels, mask):
- logits = torch.gather(
- logits.log_softmax(-1), dim=2,
- index=labels.unsqueeze(2)).squeeze(2)
- return logits * mask
-
- def get_logps(
- self,
- policy_logps, # bs, seqlen,vocab_size
- ref_logps, # bs, seqlen,vocab_size
- loss_mask, # bs, seqlen
- ):
- policy_logps = policy_logps[:, :-1].sum(-1)
- ref_logps = ref_logps[:, :-1].sum(-1)
- loss_mask = loss_mask[:, :-1]
-
- if self.loss_type == 'ipo': # average_log_prob
- policy_logps = policy_logps / loss_mask.sum(-1)
- ref_logps = ref_logps / loss_mask.sum(-1)
-
- policy_chosen_logps = policy_logps[::2]
- policy_rejected_logps = policy_logps[1::2]
- reference_chosen_logps = ref_logps[::2]
- reference_rejected_logps = ref_logps[1::2]
- return (policy_chosen_logps, policy_rejected_logps,
- reference_chosen_logps, reference_rejected_logps)
-
- def get_var_len_atten_logps(self, policy_logps, ref_logps, loss_mask,
- cu_seqlens, attention_mask):
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- # unpack sequence
- unpacked_policy_logps = torch.split(policy_logps, seqlens, dim=1)
- unpacked_ref_logps = torch.split(ref_logps, seqlens, dim=1)
- unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
- if attention_mask is not None:
- # It indicate that we pad the original sequence, labels,
- # position_ids and cumulative_len for sequence parallel if the
- # attention_mask is not None.
- # We then need to remove the padded segments.
- assert False in attention_mask
- unpacked_policy_logps = unpacked_policy_logps[:-1]
- unpacked_ref_logps = unpacked_ref_logps[:-1]
- unpacked_loss_mask = unpacked_loss_mask[:-1]
- assert len(unpacked_policy_logps) % 2 == 0
-
- def compute_logps(_logps, _mask):
- _logps = _logps[:, :-1].sum(-1)
- _mask = _mask[:, :-1]
- if self.loss_type == 'ipo':
- _logps /= _mask.sum(-1)
- return _logps
-
- (policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
- reference_rejected_logps) = [], [], [], []
- for i in range(len(unpacked_policy_logps) // 2):
- chosen = unpacked_policy_logps[2 * i]
- rejected = unpacked_policy_logps[2 * i + 1]
- chosen_ref = unpacked_ref_logps[2 * i]
- rejected_ref = unpacked_ref_logps[2 * i + 1]
- chosen_mask = unpacked_loss_mask[2 * i]
- rejected_mask = unpacked_loss_mask[2 * i + 1]
- policy_chosen_logps.append(compute_logps(chosen, chosen_mask))
- policy_rejected_logps.append(
- compute_logps(rejected, rejected_mask))
- reference_chosen_logps.append(
- compute_logps(chosen_ref, chosen_mask))
- reference_rejected_logps.append(
- compute_logps(rejected_ref, rejected_mask))
-
- return (torch.stack(policy_chosen_logps),
- torch.stack(policy_rejected_logps),
- torch.stack(reference_chosen_logps),
- torch.stack(reference_rejected_logps))
-
- @staticmethod
- def _split_for_sequence_parallel(data):
- # attention mask should not be split
- ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels')
- sp_group = get_sequence_parallel_group()
- for key in ARGS_NEED_TO_SPLIT:
- val = data.get(key, None)
- if val is not None:
- # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
- data[key] = split_for_sequence_parallel(
- val, dim=1, sp_group=sp_group)
- return data
-
- def compute_loss(self, data, data_samples=None):
- # modified from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py # noqa
- # shift labels first and add a dummy label at the end, to support sequence parallel # noqa
- data['labels'] = torch.cat(
- (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
- dim=1)
- tmp_label = data['labels'].clone()
- tmp_label[tmp_label == 0] = -100
- all_loss_mask = data[
- 'labels'] != -100 # loss mask of all tokens in all sp ranks # noqa
-
- if get_sequence_parallel_world_size() > 1:
- data = self._split_for_sequence_parallel(data)
-
- all_logits = self.llm(**data).logits
- with torch.no_grad():
- if self.ref_llm is None:
- with self.llm.disable_adapter():
- all_ref_logits = self.llm(**data).logits
- else:
- all_ref_logits = self.ref_llm(**data).logits
-
- labels = data['labels']
- labels[labels == -100] = 0
- loss_mask = labels != 0 # loss mask in a single sp rank
- policy_logps = self._gather_masked_logits(all_logits, labels,
- loss_mask)
- ref_logps = self._gather_masked_logits(all_ref_logits, labels,
- loss_mask)
-
- if get_sequence_parallel_world_size() > 1:
- policy_logps = gather_forward_split_backward(
- policy_logps,
- dim=1,
- sp_group=get_sequence_parallel_group(),
- grad_scale='up')
- ref_logps = gather_forward_split_backward(
- ref_logps,
- dim=1,
- sp_group=get_sequence_parallel_group(),
- grad_scale='up')
-
- if not self.use_varlen_attn:
- (policy_chosen_logps, policy_rejected_logps,
- reference_chosen_logps,
- reference_rejected_logps) = self.get_logps(
- policy_logps, ref_logps, all_loss_mask)
- else:
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
- (policy_chosen_logps, policy_rejected_logps,
- reference_chosen_logps,
- reference_rejected_logps) = self.get_var_len_atten_logps(
- policy_logps, ref_logps, all_loss_mask, cu_seqlens,
- data['attention_mask'])
-
- pi_logratios = policy_chosen_logps - policy_rejected_logps
- ref_logratios = reference_chosen_logps - reference_rejected_logps
-
- logits = pi_logratios - ref_logratios
- if self.loss_type == 'sigmoid':
- loss = (-F.logsigmoid(self.beta * logits) *
- (1 - self.label_smoothing) -
- F.logsigmoid(-self.beta * logits) * self.label_smoothing)
- elif self.loss_type == 'robust':
- loss = (-F.logsigmoid(self.beta * logits) *
- (1 - self.label_smoothing) +
- F.logsigmoid(-self.beta * logits) *
- self.label_smoothing) / (1 - 2 * self.label_smoothing)
- elif self.loss_type == 'hinge':
- loss = torch.relu(1 - self.beta * logits)
- elif self.loss_type == 'ipo':
- # eqn (17) of the paper where beta is the regularization
- # parameter for the IPO loss, denoted by tau in the paper. # noqa
- loss = (logits - 1 / (2 * self.beta))**2
- elif self.loss_type == 'kto_pair':
- # eqn (7) of the HALOs paper
- chosen_KL = (policy_chosen_logps -
- reference_chosen_logps).mean().clamp(min=0)
- rejected_KL = (policy_rejected_logps -
- reference_rejected_logps).mean().clamp(min=0)
-
- chosen_logratios = policy_chosen_logps - reference_chosen_logps
- rejected_logratios = \
- policy_rejected_logps - reference_rejected_logps
- # As described in the KTO report, the KL term for chosen (rejected)
- # is estimated using the rejected (chosen) half. # noqa
- loss = torch.cat(
- (
- 1 - F.sigmoid(self.beta *
- (chosen_logratios - rejected_KL)),
- 1 - F.sigmoid(self.beta *
- (chosen_KL - rejected_logratios)),
- ),
- 0,
- )
- elif self.loss_type == 'sppo_hard':
- # In the paper (https://arxiv.org/pdf/2405.00675),
- # SPPO employs a soft probability approach,
- # estimated using the PairRM score. The probability calculation
- # is conducted outside of the trainer class.
- # The version described here is the hard probability version,
- # where P in Equation (4.7) of Algorithm 1 is set to 1 for
- # the winner and 0 for the loser.
- a = policy_chosen_logps - reference_chosen_logps
- b = policy_rejected_logps - reference_rejected_logps
-
- loss = (a - 0.5 / self.beta)**2 + (b + 0.5 / self.beta)**2
- elif self.loss_type == 'nca_pair':
- chosen_rewards = (policy_chosen_logps -
- reference_chosen_logps) * self.beta
- rejected_rewards = (policy_rejected_logps -
- reference_rejected_logps) * self.beta
- loss = (-F.logsigmoid(chosen_rewards) -
- 0.5 * F.logsigmoid(-chosen_rewards) -
- 0.5 * F.logsigmoid(-rejected_rewards))
- else:
- raise ValueError(
- f'Unknown loss type: {self.loss_type}. Should be one of '
- "['sigmoid', 'hinge', 'ipo', 'kto_pair', "
- "'sppo_hard', 'nca_pair', 'robust']")
- # for logging
- chosen_rewards = self.beta * (
- policy_chosen_logps - reference_chosen_logps)
- rejected_rewards = self.beta * (
- policy_rejected_logps - reference_rejected_logps)
- reward_acc = (chosen_rewards > rejected_rewards).float().mean()
-
- loss_dict = {
- 'loss': loss,
- 'chosen_rewards': chosen_rewards.mean(),
- 'rejected_rewards': rejected_rewards.mean(),
- 'reward_acc': reward_acc,
- 'reward_margin': (chosen_rewards - rejected_rewards).mean(),
- }
- return loss_dict
diff --git a/code/xtuner/model/dynamic_llava/__init__.py b/code/xtuner/model/dynamic_llava/__init__.py
deleted file mode 100644
index 64fbb89b8b67748d996972e8e5b4d7056f4c5b87..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dynamic_llava/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .dynamic_llava_qwen import DynamicLLaVAQwen25
-from .dynamic_qwen import DynamicQwen2ForCausalLM
-
-__all__ = [
- "DynamicLLaVAQwen25", "DynamicQwen2ForCausalLM"
-]
\ No newline at end of file
diff --git a/code/xtuner/model/dynamic_llava/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/dynamic_llava/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index aaf00df8ac4432220b451ff8054598f6c33aeb4d..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/dynamic_llava/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/dynamic_llava/__pycache__/cache_utils.cpython-311.pyc b/code/xtuner/model/dynamic_llava/__pycache__/cache_utils.cpython-311.pyc
deleted file mode 100644
index ea381d100914eaf6392b736755d5a7c11aa52a87..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/dynamic_llava/__pycache__/cache_utils.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/dynamic_llava/__pycache__/custom_transformer_layer.cpython-311.pyc b/code/xtuner/model/dynamic_llava/__pycache__/custom_transformer_layer.cpython-311.pyc
deleted file mode 100644
index 1d5eb900685c565611ce372a1d2f0d54af7ce736..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/dynamic_llava/__pycache__/custom_transformer_layer.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/dynamic_llava/__pycache__/dynamic_llava_qwen.cpython-311.pyc b/code/xtuner/model/dynamic_llava/__pycache__/dynamic_llava_qwen.cpython-311.pyc
deleted file mode 100644
index e6ae7b2fa4a15476586bffdc39b584b87e1dd4e9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/dynamic_llava/__pycache__/dynamic_llava_qwen.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/dynamic_llava/__pycache__/dynamic_qwen.cpython-311.pyc b/code/xtuner/model/dynamic_llava/__pycache__/dynamic_qwen.cpython-311.pyc
deleted file mode 100644
index 4f50d700676149e668d6542b330a661147cc9b71..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dynamic_llava/__pycache__/dynamic_qwen.cpython-311.pyc
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6938fadbbd361a8853c15d7f1a6acca7dc09efe9dc40fa3e4183245d71a5030d
-size 100558
diff --git a/code/xtuner/model/dynamic_llava/cache_utils.py b/code/xtuner/model/dynamic_llava/cache_utils.py
deleted file mode 100644
index 4cc873f4e7242e0afefe719bbe147f28a4ca1833..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dynamic_llava/cache_utils.py
+++ /dev/null
@@ -1,320 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-import torch
-
-
-class Cache:
- """
- Base, abstract class for all caches. The actual data structure is specific to each subclass.
- """
-
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
- cache to be created.
-
- Return:
- A tuple containing the updated key and value states.
- """
- raise NotImplementedError("Make sure to implement `update` in a subclass.")
-
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- raise NotImplementedError(
- "Make sure to implement `get_seq_length` in a subclass."
- )
-
- def get_max_length(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states, if there is any."""
- raise NotImplementedError(
- "Make sure to implement `get_max_length` in a subclass."
- )
-
- def get_usable_length(
- self, new_seq_length: int, layer_idx: Optional[int] = 0
- ) -> int:
- """Given the sequence length of the new inputs, returns the usable length of the cache."""
- # Cache without size limit -> all cache is usable
- # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
- # length, we will need to evict part of the cache (and thus not all cache is usable)
- max_length = self.get_max_length()
- previous_seq_length = self.get_seq_length(layer_idx)
- if max_length is not None and previous_seq_length + new_seq_length > max_length:
- return max_length - new_seq_length
- return previous_seq_length
-
-
-class DynamicCachePlus(Cache):
- """
- A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
- `[batch_size, num_heads, seq_len, head_dim]`.
- """
-
- def __init__(self) -> None:
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- self.seen_tokens = (
- 0 # Used in `generate` to keep tally of how many tokens the cache has seen
- )
-
- # ----------------------------------------------------------#
- self.true_cache_length: List[torch.Tensor] = [] # L * [B]
- # ----------------------------------------------------------#
-
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
- """
- Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
- sequence length.
- """
- if layer_idx < len(self):
- return (self.key_cache[layer_idx], self.value_cache[layer_idx])
- else:
- raise KeyError(
- f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
- )
-
- def __iter__(self):
- """
- Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
- keys and values
- """
- for layer_idx in range(len(self)):
- yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
-
- def __len__(self):
- """
- Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
- to the number of layers in the model.
- """
- return len(self.key_cache)
-
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- cache_decision: Optional[torch.Tensor] = None, # [B * N]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
- Return:
- A tuple containing the updated key and value states.
- """
- B, _, N, _ = key_states.shape
-
- # Update the number of seen tokens
- if layer_idx == 0:
- self.seen_tokens += key_states.shape[-2]
-
- if len(self.key_cache) <= layer_idx:
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
-
- # ----------------------------------------------------------#
- # for prefill stage
- if cache_decision is not None:
- self.true_cache_length.append(cache_decision.sum(dim=-1))
- else:
- self.true_cache_length.append(torch.tensor([N]).repeat(B))
- # ----------------------------------------------------------#
- else:
- # ----------------------------------------------------------#
- if cache_decision is not None:
- if B == 1 and N == 1:
- if cache_decision[0, 0]:
- self.key_cache[layer_idx] = torch.cat(
- [self.key_cache[layer_idx], key_states], dim=-2
- )
- self.value_cache[layer_idx] = torch.cat(
- [self.value_cache[layer_idx], value_states], dim=-2
- )
-
- self.true_cache_length[layer_idx] += N
- else:
- pass
- else: # TODO, efficiency needs to be optimized
- cur_layer_key_cache_batch_list = []
- cur_layer_value_cache_batch_list = []
- for b in range(B):
- cur_keep_indice = cache_decision[b]
- keep_key_states = key_states[
- b, :, cur_keep_indice, :
- ] # H * N * C
- keep_value_states = value_states[
- b, :, cur_keep_indice, :
- ] # H * N * C
- cur_layer_key_cache = torch.cat(
- [
- self.key_cache[layer_idx][
- b, :, : self.true_cache_length[layer_idx][b], :
- ],
- keep_key_states,
- ],
- dim=-2,
- )
- cur_layer_value_cache = torch.cat(
- [
- self.value_cache[layer_idx][
- b, :, : self.true_cache_length[layer_idx][b], :
- ],
- keep_value_states,
- ],
- dim=-2,
- )
- cur_layer_key_cache_batch_list.append(cur_layer_key_cache)
- cur_layer_value_cache_batch_list.append(cur_layer_value_cache)
-
- self.true_cache_length[layer_idx][b] += (
- cache_decision[b].sum().item()
- )
-
- max_cur_layer_kv_cache_length = max(
- cur_layer_key_cache_batch_list[b].shape[-2] for b in range(B)
- )
- for b in range(B):
- cur_len = cur_layer_key_cache_batch_list[b].shape[-2]
- cur_layer_key_cache_batch_list[b] = torch.cat(
- [
- cur_layer_key_cache_batch_list[b],
- torch.zeros(
- (
- cur_layer_key_cache_batch_list[b].shape[0],
- max_cur_layer_kv_cache_length - cur_len,
- cur_layer_key_cache_batch_list[b].shape[-1],
- ),
- dtype=cur_layer_key_cache_batch_list[b].dtype,
- device=cur_layer_key_cache_batch_list[b].device,
- ),
- ],
- dim=-2,
- )
- cur_layer_value_cache_batch_list[b] = torch.cat(
- [
- cur_layer_value_cache_batch_list[b],
- torch.zeros(
- (
- cur_layer_value_cache_batch_list[b].shape[0],
- max_cur_layer_kv_cache_length - cur_len,
- cur_layer_value_cache_batch_list[b].shape[-1],
- ),
- dtype=cur_layer_value_cache_batch_list[b].dtype,
- device=cur_layer_value_cache_batch_list[b].device,
- ),
- ],
- dim=-2,
- )
- self.key_cache[layer_idx] = torch.stack(
- cur_layer_key_cache_batch_list
- )
- self.value_cache[layer_idx] = torch.stack(
- cur_layer_value_cache_batch_list
- )
- else:
- self.key_cache[layer_idx] = torch.cat(
- [self.key_cache[layer_idx], key_states], dim=-2
- )
- self.value_cache[layer_idx] = torch.cat(
- [self.value_cache[layer_idx], value_states], dim=-2
- )
-
- self.true_cache_length[layer_idx] += N
- # ----------------------------------------------------------#
-
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
- # ----------------------------------------------------------#
- def get_cache(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- if len(self.key_cache) <= layer_idx:
- return key_states, value_states
- else:
- return torch.cat(
- [self.key_cache[layer_idx], key_states], dim=-2
- ), torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
- # ----------------------------------------------------------#
-
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- if len(self.key_cache) <= layer_idx:
- return 0
- return self.key_cache[layer_idx].shape[-2]
-
- def get_max_length(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
- return None
-
- def reorder_cache(self, beam_idx: torch.LongTensor):
- """Reorders the cache for beam search, given the selected beam indices."""
- for layer_idx in range(len(self.key_cache)):
- device = self.key_cache[layer_idx].device
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
- 0, beam_idx.to(device)
- )
- device = self.value_cache[layer_idx].device
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
- 0, beam_idx.to(device)
- )
-
- # ----------------------------------------------------------#
- def to_legacy_cache(
- self,
- ) -> Tuple[Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]], torch.Tensor]:
- """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
- legacy_cache = ()
- for layer_idx in range(len(self)):
- legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
- return (legacy_cache, self.true_cache_length)
-
- @classmethod
- def from_legacy_cache(
- cls,
- past_key_values: Optional[
- Tuple[Tuple[Tuple[torch.FloatTensor]], torch.Tensor]
- ] = None,
- ) -> "DynamicCachePlus":
- """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
- cache = cls()
- if past_key_values is not None:
- for layer_idx in range(len(past_key_values[0])):
- key_states, value_states = past_key_values[0][layer_idx]
- cache.update(key_states, value_states, layer_idx)
- cache.true_cache_length = past_key_values[1]
- return cache
-
- # ----------------------------------------------------------#
diff --git a/code/xtuner/model/dynamic_llava/custom_transformer_layer.py b/code/xtuner/model/dynamic_llava/custom_transformer_layer.py
deleted file mode 100644
index f0e5d34a43f41234cfed8c60682299db275a22e4..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dynamic_llava/custom_transformer_layer.py
+++ /dev/null
@@ -1,379 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Optional, Tuple
-import collections.abc
-from itertools import repeat
-from functools import partial
-
-from torch.jit import Final
-
-# use torch.scaled_dot_product_attention where possible
-_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
-if "TIMM_FUSED_ATTN" in os.environ:
- _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"])
-else:
- _USE_FUSED_ATTN = (
- 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
- )
-
-# Set to True if exporting a model with Same padding via ONNX
-_EXPORTABLE = False
-
-
-# From PyTorch internals
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
- return tuple(x)
- return tuple(repeat(x, n))
-
- return parse
-
-
-to_2tuple = _ntuple(2)
-
-
-def use_fused_attn(experimental: bool = False) -> bool:
- # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
- if not _HAS_FUSED_ATTN or _EXPORTABLE:
- return False
- if experimental:
- return _USE_FUSED_ATTN > 1
- return _USE_FUSED_ATTN > 0
-
-
-def drop_path(
- x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
- 'survival rate' as the argument.
-
- """
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (
- x.ndim - 1
- ) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return x * random_tensor
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
- def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
-
- def extra_repr(self):
- return f"drop_prob={round(self.drop_prob,3):0.3f}"
-
-
-class Mlp(nn.Module):
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- norm_layer=None,
- bias=True,
- drop=0.0,
- use_conv=False,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- bias = to_2tuple(bias)
- drop_probs = to_2tuple(drop)
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
-
- self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.norm = (
- norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
- )
- self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
- self.drop2 = nn.Dropout(drop_probs[1])
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop1(x)
- x = self.norm(x)
- x = self.fc2(x)
- x = self.drop2(x)
- return x
-
-
-class Attention(nn.Module):
- fused_attn: Final[bool]
-
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = False,
- attn_drop: float = 0.0,
- proj_drop: float = 0.0,
- norm_layer: nn.Module = nn.LayerNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, "dim should be divisible by num_heads"
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.fused_attn = use_fused_attn()
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, N, C = x.shape
- qkv = (
- self.qkv(x)
- .reshape(B, N, 3, self.num_heads, self.head_dim)
- .permute(2, 0, 3, 1, 4)
- )
- q, k, v = qkv.unbind(0)
- q, k = self.q_norm(q), self.k_norm(k)
-
- if self.fused_attn:
- x = F.scaled_dot_product_attention(
- q,
- k,
- v,
- dropout_p=self.attn_drop.p if self.training else 0.0,
- )
- else:
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class CrossAttention(nn.Module):
- fused_attn: Final[bool]
-
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = False,
- attn_drop: float = 0.0,
- proj_drop: float = 0.0,
- norm_layer: nn.Module = nn.LayerNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, "dim should be divisible by num_heads"
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.fused_attn = use_fused_attn()
-
- # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q = nn.Linear(dim, dim, bias=qkv_bias)
- self.k = nn.Linear(dim, dim, bias=qkv_bias)
- self.v = nn.Linear(dim, dim, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, qkv: Tuple[torch.Tensor]) -> torch.Tensor:
- assert len(qkv) == 3, "the number of qkv should be 3"
- B_q, N_q, C_q = qkv[0].shape
- B_k, N_k, C_k = qkv[1].shape
- B_v, N_v, C_v = qkv[2].shape
- # qkv = (
- # self.qkv(x)
- # .reshape(B, N, 3, self.num_heads, self.head_dim)
- # .permute(2, 0, 3, 1, 4)
- # )
- # q, k, v = qkv.unbind(0)
- q = (
- self.q(qkv[0])
- .reshape(B_q, N_q, self.num_heads, self.head_dim)
- .permute(0, 2, 1, 3)
- )
- k = (
- self.k(qkv[1])
- .reshape(B_k, N_k, self.num_heads, self.head_dim)
- .permute(0, 2, 1, 3)
- )
- v = (
- self.v(qkv[2])
- .reshape(B_v, N_v, self.num_heads, self.head_dim)
- .permute(0, 2, 1, 3)
- )
- q, k = self.q_norm(q), self.k_norm(k)
-
- if self.fused_attn:
- x = F.scaled_dot_product_attention(
- q,
- k,
- v,
- dropout_p=self.attn_drop.p if self.training else 0.0,
- )
- else:
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
-
- x = x.transpose(1, 2).reshape(B_q, N_q, C_v)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-class LayerScale(nn.Module):
- def __init__(
- self,
- dim: int,
- init_values: float = 1e-5,
- inplace: bool = False,
- ) -> None:
- super().__init__()
- self.inplace = inplace
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
-
-
-class SelfTransformerEncoderBlock(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = False,
- qk_norm: bool = False,
- proj_drop: float = 0.0,
- attn_drop: float = 0.0,
- init_values: Optional[float] = None,
- drop_path: float = 0.0,
- act_layer: nn.Module = nn.GELU,
- norm_layer: nn.Module = nn.LayerNorm,
- mlp_layer: nn.Module = Mlp,
- ) -> None:
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_norm=qk_norm,
- attn_drop=attn_drop,
- proj_drop=proj_drop,
- norm_layer=norm_layer,
- )
- self.ls1 = (
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- )
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- self.norm2 = norm_layer(dim)
- self.mlp = mlp_layer(
- in_features=dim,
- hidden_features=int(dim * mlp_ratio),
- act_layer=act_layer,
- drop=proj_drop,
- )
- self.ls2 = (
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- )
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
- x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
- return x
-
-
-class CrossTransformerEncoderBlock(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = False,
- qk_norm: bool = False,
- proj_drop: float = 0.0,
- attn_drop: float = 0.0,
- init_values: Optional[float] = None,
- drop_path: float = 0.0,
- act_layer: nn.Module = nn.GELU,
- norm_layer: nn.Module = nn.LayerNorm,
- mlp_layer: nn.Module = Mlp,
- ) -> None:
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = CrossAttention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_norm=qk_norm,
- attn_drop=attn_drop,
- proj_drop=proj_drop,
- norm_layer=norm_layer,
- )
- self.ls1 = (
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- )
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- self.norm2 = norm_layer(dim)
- self.mlp = mlp_layer(
- in_features=dim,
- hidden_features=int(dim * mlp_ratio),
- act_layer=act_layer,
- drop=proj_drop,
- )
- self.ls2 = (
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- )
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- def forward(self, qkv: Tuple[torch.Tensor]) -> torch.Tensor:
- assert len(qkv) == 3, "the number of qkv should be 3"
- x = qkv[0] + self.drop_path1(
- self.ls1(
- self.attn((self.norm1(qkv[0]), self.norm1(qkv[1]), self.norm1(qkv[2])))
- )
- )
- # x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
- x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
- return (x, qkv[1], qkv[2])
diff --git a/code/xtuner/model/dynamic_llava/dynamic_llava_qwen.py b/code/xtuner/model/dynamic_llava/dynamic_llava_qwen.py
deleted file mode 100644
index e6f16e2150f1b7250cfd5689ce8d3b561f6baf3b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dynamic_llava/dynamic_llava_qwen.py
+++ /dev/null
@@ -1,1118 +0,0 @@
-from xtuner.model.utils import (
- prepare_inputs_labels_for_multimodal as _legacy_prepare_m4,
-)
-from xtuner.model.utils import find_all_linear_names, find_all_linear_names_for_dynamic_qwen2
-from mmengine.registry import MODELS
-
-from xtuner.model.llava import (
- LLaVAModel as _BaseLLaVA, # reuse builder helpers via composition
-)
-from typing import Dict, List, Optional, Tuple, Union
-from mmengine.model import BaseModel
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from xtuner.registry import BUILDER
-from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from peft import LoraConfig, TaskType, PeftConfig
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from transformers import (
- AutoConfig,
-)
-from .dynamic_qwen import (
- DynamicQwen2ForCausalLM,
- SparseConfig,
-)
-from collections import OrderedDict
-from typing import Iterable
-
-from xtuner.model.torchscale.model.LongNet import make_longnet_from_name
-from xtuner.model.modules import ProjectorModel, ProjectorConfig, dispatch_modules
-from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX
-from xtuner.model.utils import traverse_dict, LoadWoInit, make_inputs_require_grad
-from peft.tuners.lora import LoraLayer
-
-def list_lora_modules(model):
- rows = []
- for name, module in model.named_modules():
- if isinstance(module, LoraLayer):
- # active adapters on this module
- adapters = list(getattr(module, "lora_A", {}).keys())
- # try to infer rank per adapter from A's weight shape [r, in_features]
- ranks = {a: getattr(module, "lora_A", {})[a].weight.shape[0] for a in adapters}
- rows.append((name, module.__class__.__name__, adapters, ranks))
- return rows
-
-def print_lora_modules(model):
- rows = list_lora_modules(model)
- if not rows:
- print("No LoRA modules found.")
- return
- print(f"Found {len(rows)} LoRA-injected modules:\n")
- for name, cls, adapters, ranks in rows:
- rs = ", ".join(f"{a}: r={ranks.get(a, '?')}" for a in adapters) or "—"
- print(f"- {name} [{cls}] adapters=({rs})")
-
-
-# --- predictor init helpers -----------------------------------------------
-def _init_predictor_module(m: nn.Module):
- """Initialize predictor submodules only (Linear/LN)."""
- for mod in m.modules():
- if isinstance(mod, nn.Linear):
- if mod.out_features == 2:
- # logits head: small init keeps early training stable
- nn.init.trunc_normal_(mod.weight, std=1e-3)
- else:
- nn.init.xavier_uniform_(mod.weight, gain=math.sqrt(2.0))
- if mod.bias is not None:
- nn.init.zeros_(mod.bias)
- elif isinstance(mod, nn.LayerNorm):
- if mod.elementwise_affine:
- nn.init.ones_(mod.weight)
- nn.init.zeros_(mod.bias)
-
-def _collect_predictor_modules(base_llm: nn.Module):
- """Yield (name, module) for the 3 predictor stacks inside the Qwen model."""
- want_names = (
- "model.image_score_predictor",
- "model.output_text_score_predictor",
- "model.instruct_score_predictor",
- )
- # Be permissive about PEFT wrappers (extra 'base_model.model.model.' etc.)
- lowers = [w.lower() for w in want_names]
- for name, module in base_llm.named_modules():
- lname = name.lower()
- if any(l in lname for l in lowers):
- # Return the *root* predictor module (not the internals)
- # Filter out nested children: only top-most matches.
- if any(lname.endswith(w) for w in ("image_score_predictor",
- "output_text_score_predictor",
- "instruct_score_predictor")):
- yield name, module
-
-def _init_predictors_if_missing(base_llm: nn.Module, missing_keys: list[str], verbose: bool = True):
- """Initialize predictors if their params were missing from the loaded checkpoint."""
- predictor_roots = (
- "model.image_score_predictor",
- "model.output_text_score_predictor",
- "model.instruct_score_predictor",
- # PEFT nesting variants:
- "base_model.model.model.image_score_predictor",
- "base_model.model.model.output_text_score_predictor",
- "base_model.model.model.instruct_score_predictor",
- )
- need_init = any(any(k.startswith(root) for root in predictor_roots) for k in missing_keys)
- if not need_init:
- return False
-
- inited = []
- for name, mod in _collect_predictor_modules(base_llm):
- _init_predictor_module(mod)
- inited.append(name)
- if verbose and inited:
- print_log(f"Initialized predictor modules (missing in ckpt): {inited}", "current")
- return bool(inited)
-
-def _force_init_all_predictors(base_llm: nn.Module, verbose: bool = True):
- """Always initialize all predictor stacks (use when building from a base HF model)."""
- inited = []
- for name, mod in _collect_predictor_modules(base_llm):
- _init_predictor_module(mod)
- inited.append(name)
- if verbose and inited:
- print_log(f"Initialized predictor modules: {inited}", "current")
- return bool(inited)
-
-@BUILDER.register_module()
-class DynamicLLaVAQwen25(BaseModel):
- """XTuner LLaVA variant for Qwen2.5 + QLoRA + dynamic sparsity, LongNet intact."""
- _QWEN2_LORA_TARGETS = ("q_proj", "k_proj", "v_proj", "o_proj",
- "gate_proj", "up_proj", "down_proj")
- def __init__(
- self,
- llm: Dict,
- freeze_llm: bool = True,
- pretrained_pth: Optional[str] = None,
- projector_depth: int = 2,
- llm_lora: Optional[Dict] = None,
- use_activation_checkpointing: bool = True,
- max_position_embeddings=None,
- hidden_size: int = 512,
- train_stage: str = '1',
- enable_long_net: bool = True,
- long_net_pth: Optional[str] = None,
- projector_pth: Optional[str] = None,
- image_feature_length: int = 196,
- sparse_config: Optional[Dict] = None,
- divprune_ratio: float = 0.1,
- ):
- super().__init__()
- self.hidden_size = hidden_size
- self.image_feature_length = image_feature_length
- self.divprune_ratio = divprune_ratio
-
- # # 1) Build LLM config (Auto class -> trust_remote_code is OK here)
- # llm_cfg = AutoConfig.from_pretrained(llm["pretrained_model_name_or_path"], trust_remote_code=True)
- # llm_cfg.sparse_config = (sparse_config or SparseConfig().__dict__)
-
- # # Optional quantization
- # quant_cfg = llm.get("quantization_config", None)
- # if isinstance(quant_cfg, dict) and "type" in quant_cfg:
- # quant_type = quant_cfg.pop("type")
- # if isinstance(quant_type, str):
- # if quant_type.lower() == "bitsandbytesconfig":
- # from transformers import BitsAndBytesConfig as _BB
- # quant_type = _BB
- # quant_cfg = quant_type(**quant_cfg)
-
- # # 2) Build model (NON-Auto class -> DO NOT pass trust_remote_code)
- # self.llm: DynamicQwen2ForCausalLM = DynamicQwen2ForCausalLM.from_pretrained(
- # llm["pretrained_model_name_or_path"],
- # config=llm_cfg,
- # torch_dtype=llm.get("torch_dtype", None),
- # attn_implementation=llm.get("attn_implementation", "flash_attention_2"),
- # quantization_config=quant_cfg,
- # low_cpu_mem_usage=True,
- # )
-
- self.enable_long_net = enable_long_net
-
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
-
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = False #False
- self.freeze_long_net = True #False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- # print('sparse_config', sparse_config)
- llm = self._dispatch_lm_model_cfg(llm, sparse_config, max_position_embeddings)
- # add sparse config to llm config
- # if isinstance(llm, ConfigDict):
- if sparse_config is not None:
- llm.sparse_config = SparseConfig(**sparse_config)
- else:
- llm.sparse_config = SparseConfig()
-
-
- self.llm = self._build_from_cfg_or_module(llm)
- print_log("Building DynamicLLaVAQwen25 with LLM: {}".format(self.llm.__class__.__name__), "current")
- if self.training:
- _force_init_all_predictors(self._get_llm_base())
-
- self.llm.config.use_cache = False
- self.config = self.llm.config
- # dispatch_modules(self.llm)
-
- # print_log("Dtype of llm: {}".format(self.llm.dtype), "current")
-
- # 4) LongNet + Projector (reuse your existing)
- # from .llava import make_longnet_from_name, ProjectorModel, ProjectorConfig
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name) if enable_long_net else nn.Identity()
- if not isinstance(self.LongNet_encoder, nn.Identity):
- self.LongNet_encoder = self.LongNet_encoder.to(device=self.llm.device, dtype=self.llm.dtype)
-
- self.projector = ProjectorModel(ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=projector_depth,
- )).to(self.llm.dtype)
- # dispatch_modules(self.llm)
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
-
- self.projector.enable_input_require_grads()
- # self.LongNet_encoder.enable_input_require_grads()
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- # 3) LoRA (QLoRA)
- if llm_lora is not None:
- lora_config = self._parse_lora_config(llm_lora)
- from peft import get_peft_model, prepare_model_for_kbit_training
- prepare_model_for_kbit_training(self.llm, use_gradient_checkpointing=use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names_for_dynamic_qwen2(self.llm)
- lora_config.target_modules = modules
- lora_config.modules_to_save = ["embed_tokens", "lm_head"]
- self.llm = get_peft_model(self.llm, lora_config)
- print_lora_modules(self.llm)
-
- if self.training:
- self._enable_predictors_training()
-
- # print_log("Dtype of llm after LoRA: {}".format(self.llm.dtype), "current")
- if long_net_pth:
- from safetensors.torch import load_file
- print_log(f"Loading LongNet from {long_net_pth}", "current")
- self.LongNet_encoder.load_state_dict(load_file(long_net_pth, device="cpu"), strict=False)
- self.LongNet_encoder.to(self.llm.dtype)
-
- if projector_pth:
- from safetensors.torch import load_file
- print_log(f"Loading projector from {projector_pth}", "current")
- self.projector.load_state_dict(load_file(projector_pth, device="cpu"), strict=False)
- self.projector.to(self.llm.dtype)
-
- #5) Tokenizer (Auto class -> trust_remote_code is OK)
- # from transformers import AutoTokenizer
- # tok_name = llm.get("pretrained_model_name_or_path")
- # self.tokenizer = AutoTokenizer.from_pretrained(tok_name, trust_remote_code=True, use_fast=True)
- # self._user_ids = self.tokenizer.encode("USER:", add_special_tokens=False)
- # self._assistant_ids = self.tokenizer.encode("ASSISTANT:", add_special_tokens=False)
-
-
- self._is_init = True
-
- self.is_first_iter = True
-
-
- def _get_llm_base(self):
- """Return the underlying (possibly de-PEFTed) base model object."""
- llm = self.llm
- # PEFT wrappers expose get_base_model()
- if hasattr(llm, "get_base_model"):
- try:
- return llm.get_base_model()
- except Exception:
- pass
- return llm
-
- def pairwise_cosine_similarity(self, matrix):
- norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
- cosine_similarity = torch.mm(norm_matrix, norm_matrix.t())
- return cosine_similarity
-
- def pairwise_l1_distance(matrix: torch.Tensor) -> torch.Tensor:
- """
- Compute the full pairwise L1 (Manhattan) distance matrix
- for an [N, D] tensor.
- """
- # torch.cdist with p=1 computes L1 distance
- return torch.cdist(matrix, matrix, p=1)
-
-
- def DivPrune(self, visual_feature_vectors, image_feature_length, cosine_matrix=None, threshold_ratio=0.1):
- threshold_terms = int(round(threshold_ratio * image_feature_length))
- if cosine_matrix is None:
- cosine_matrix = 1.0 - (self.pairwise_cosine_similarity(visual_feature_vectors))
-
- s = torch.empty(threshold_terms, dtype=torch.long, device=visual_feature_vectors.device)
- for i in range(threshold_terms):
- if i == 0:
- m2 = cosine_matrix
- else:
- m2 = torch.index_select(cosine_matrix, 0, torch.index_select(s, 0, torch.arange(0, i, device=cosine_matrix.device)))
-
- if i == 0:
- scores = torch.topk(m2, 2, dim=0, largest=False).values[1, :]
- else:
- scores = torch.min(m2, dim=0).values
-
- phrase_to_add_idx = torch.argmax(scores)
- s[i] = phrase_to_add_idx
- return s, cosine_matrix
-
- def _enable_predictors_training(self):
- """
- Turn ON requires_grad for predictor stacks while keeping the rest
- of the LLM frozen (except LoRA params which PEFT already enables).
- """
- base = self._get_llm_base()
-
- want = (
- "image_score_predictor", # VisionPredictor
- "output_text_score_predictor",# TextPredictor for outputs
- "instruct_score_predictor", # TextPredictor for instructions
- # fallbacks if names differ:
- "visionpredictor",
- "textpredictor",
- "instructpredictor",
- )
-
- made_trainable = []
- for name, module in base.named_modules():
- lname = name.lower()
- if any(w in lname for w in want):
- module.requires_grad_(True)
- for p in module.parameters():
- p.requires_grad = True
- made_trainable.append(name)
-
- if made_trainable:
- print_log(f"Predictor modules set trainable: {made_trainable}", "current")
- else:
- print_log("No predictor modules found to unfreeze (ok if predictors are disabled).", "current")
-
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
- # ------------------------------
- # Multimodal packing (patched)
- # ------------------------------
- def _prepare_inputs_labels_for_multimodal_dynamic(
- self,
- input_ids: torch.LongTensor = None,
- position_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- labels: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- text_features: Optional[torch.FloatTensor] = None,
- *,
- special_user_ids: Optional[torch.LongTensor] = None, # e.g., tokenized "USER:" sequence
- return_indices: bool = False
- ):
- """
- Returns:
- - if return_indices=False (default): dict (same as before)
- - if return_indices=True: (dict, (input_embeds_indices,))
- where input_embeds_indices is a list[dict] with keys:
- system: [start, end]
- image: [start, end] # union (min start, max end) of all image blocks
- images: list[[start, end], ...] # each image block range
- instruct: [start, end]
- answer: [start, end]
- last_instruct: [start, end] # if special_user_ids given; else falls back to instruct
- """
- if pixel_values is None:
- out = {
- 'input_ids': input_ids,
- 'position_ids': position_ids,
- 'attention_mask': attention_mask,
- 'past_key_values': past_key_values,
- 'inputs_embeds': None,
- 'labels': labels
- }
- return (out, (None,)) if return_indices else out
-
- # snapshot originals to decide whether to null them at the end
- _labels = labels
- _position_ids = position_ids
- _attention_mask = attention_mask
-
- # defaults
- if attention_mask is None:
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
- else:
- attention_mask = attention_mask.bool()
- if position_ids is None:
- position_ids = torch.arange(0, input_ids.shape[1],
- dtype=torch.long, device=input_ids.device)
- if labels is None:
- labels = torch.full_like(input_ids, IGNORE_INDEX)
-
- # strip padding per attention_mask
- input_ids_list = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
- labels_list = [lab[mask] for lab, mask in zip(labels, attention_mask)]
-
- new_inputs_embeds = []
- new_labels = []
- per_sample_segments = [] # will hold raw (before trunc/pad) segment ranges for indices
- cur_image_idx = 0
-
- embed_tokens = self.llm.get_input_embeddings()
-
- for b_idx, cur_input_ids in enumerate(input_ids_list):
- cur_labels = labels_list[b_idx]
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum().item()
-
- # split text around IMAGE_TOKEN_INDEX
- image_token_positions = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
- image_token_indices = [-1] + image_token_positions + [cur_input_ids.shape[0]]
-
- # collect text chunks and their labels (without IMAGE tokens)
- text_chunks_ids = []
- text_chunks_labels = []
- for i in range(len(image_token_indices) - 1):
- s = image_token_indices[i] + 1
- e = image_token_indices[i + 1]
- if s < e:
- text_chunks_ids.append(cur_input_ids[s:e])
- text_chunks_labels.append(cur_labels[s:e])
- else:
- text_chunks_ids.append(cur_input_ids[s:e]) # empty
- text_chunks_labels.append(cur_labels[s:e])
-
- # embed text (unless text_features provided for this batch)
- if text_features is None:
- if len(text_chunks_ids) > 0:
- flat_ids = torch.cat(text_chunks_ids) if text_chunks_ids else cur_input_ids.new_empty((0,))
- flat_embeds = embed_tokens(flat_ids) if flat_ids.numel() else flat_ids.new_zeros((0, embed_tokens.embedding_dim))
- split_sizes = [t.shape[0] for t in text_chunks_ids]
- text_chunks_embeds = list(torch.split(flat_embeds, split_sizes, dim=0))
- else:
- text_chunks_embeds = []
- else:
- # assume text_features[b_idx] already matches concatenated text chunks (no images)
- # split along text chunk lengths to keep consistent indexing
- split_sizes = [t.shape[0] for t in text_chunks_ids]
- if sum(split_sizes) == text_features[b_idx].shape[0]:
- text_chunks_embeds = list(torch.split(text_features[b_idx], split_sizes, dim=0))
- else:
- # fallback: embed with token embeddings if provided features don't match
- flat_ids = torch.cat(text_chunks_ids) if text_chunks_ids else cur_input_ids.new_empty((0,))
- flat_embeds = embed_tokens(flat_ids) if flat_ids.numel() else flat_ids.new_zeros((0, embed_tokens.embedding_dim))
- text_chunks_embeds = list(torch.split(flat_embeds, split_sizes, dim=0))
-
- # Interleave: text_chunk[i], (image i), text_chunk[i+1], ...
- cur_new_embeds_parts = []
- cur_new_labels_parts = []
- seg_ranges_text = [] # list of [start, end) in embed positions for each text chunk
- seg_ranges_images = [] # list of [start, end) in embed positions for each image block
-
- offset = 0
- for i in range(len(text_chunks_embeds)):
- # append text chunk i
- te = text_chunks_embeds[i]
- tl = text_chunks_labels[i]
- if te.shape[0] > 0:
- cur_new_embeds_parts.append(te)
- cur_new_labels_parts.append(tl)
- seg_ranges_text.append([offset, offset + te.shape[0]])
- offset += te.shape[0]
- else:
- seg_ranges_text.append([offset, offset]) # empty
-
- # append image i (if exists)
- if i < num_images:
- cur_pixels = pixel_values[cur_image_idx]
- cur_image_idx += 1
- if cur_pixels.shape[0] > 0:
- cur_new_embeds_parts.append(cur_pixels)
- # labels for image tokens are ignored
- cur_img_labels = torch.full((cur_pixels.shape[0],),
- IGNORE_INDEX,
- device=cur_labels.device,
- dtype=cur_labels.dtype)
- cur_new_labels_parts.append(cur_img_labels)
- seg_ranges_images.append([offset, offset + cur_pixels.shape[0]])
- offset += cur_pixels.shape[0]
- else:
- seg_ranges_images.append([offset, offset])
-
- # If no image tokens, still advance image idx once to keep parity with caller expectations
- if num_images == 0:
- # match original behavior
- _ = pixel_values[cur_image_idx]
- cur_image_idx += 1
-
- # stitch
- cur_new_embeds = torch.cat(cur_new_embeds_parts, dim=0) if len(cur_new_embeds_parts) else \
- torch.zeros((0, embed_tokens.embedding_dim), device=cur_input_ids.device, dtype=embed_tokens.weight.dtype)
- cur_new_labels = torch.cat(cur_new_labels_parts, dim=0) if len(cur_new_labels_parts) else \
- torch.full((0,), IGNORE_INDEX, device=cur_input_ids.device, dtype=cur_labels.dtype)
-
- # derive segment indices (pre-trunc/pad) similar to the second function
- cur_len = cur_new_embeds.shape[0]
- if cur_len > 0:
- # answer starts at first non-IGNORE label; if none, set to cur_len
- non_ignore = torch.nonzero(cur_new_labels != IGNORE_INDEX, as_tuple=False)
- answer_start = non_ignore[0].item() if non_ignore.numel() else cur_len
- else:
- answer_start = 0
-
- # system: everything before first image (i.e., first text chunk)
- system_start, system_end = seg_ranges_text[0] if seg_ranges_text else [0, 0]
-
- # instruct: after first image block (if any) up to answer_start
- if seg_ranges_images:
- instruct_start = seg_ranges_images[0][1]
- else:
- instruct_start = system_end
- instruct_end = min(answer_start, cur_len)
- instruct_start = min(instruct_start, instruct_end)
-
- # union image range and per-image ranges
- if seg_ranges_images:
- image_start_union = seg_ranges_images[0][0]
- image_end_union = seg_ranges_images[-1][1]
- else:
- image_start_union = image_end_union = instruct_start # empty
-
- # last_instruct: optionally find last "USER:" marker within instruct tokens
- # Compute within original token ids to locate the last occurrence,
- # then map back into embed coordinates accounting for interleaved images.
- if special_user_ids is not None and special_user_ids.numel() > 0 and len(image_token_positions) > 0:
- # define instruct slice in original ids: after first IMAGE token up to answer start in original ids
- first_image_pos = image_token_positions[0]
- # last index in original where labels == IGNORE_INDEX (prompt)
- non_ignore_orig = torch.nonzero(cur_labels != IGNORE_INDEX, as_tuple=False)
- answer_start_orig = non_ignore_orig[0].item() if non_ignore_orig.numel() else cur_labels.shape[0]
- instr_ids = cur_input_ids[first_image_pos + 1: answer_start_orig]
- # search for the last occurrence of special_user_ids in instr_ids
- last_user_pos = None
- if instr_ids.numel() >= special_user_ids.numel():
- # naive sliding window
- L = special_user_ids.shape[0]
- for i in range(instr_ids.shape[0] - L, -1, -1):
- if torch.equal(instr_ids[i:i+L], special_user_ids.to(instr_ids.device)):
- last_user_pos = i
- break
- if last_user_pos is not None:
- # map this position into embed coordinates:
- # it lies in concatenation of text_chunks_ids[1:], before answer.
- remaining = last_user_pos
- last_seg_idx = None
- in_seg_offset = 0
- # walk text chunks 1.. to locate which chunk contains it
- for ti in range(1, len(text_chunks_ids)):
- tlen = text_chunks_ids[ti].shape[0]
- if remaining < tlen:
- last_seg_idx = ti
- in_seg_offset = remaining
- break
- remaining -= tlen
- if last_seg_idx is None:
- # falls back to start of instruct if not found
- last_instruct_start_embed = instruct_start
- else:
- # embedding start of that text chunk:
- last_text_range = seg_ranges_text[last_seg_idx]
- last_instruct_start_embed = last_text_range[0] + in_seg_offset
- else:
- last_instruct_start_embed = instruct_start
- else:
- last_instruct_start_embed = instruct_start
-
- per_sample_segments.append({
- "system": [int(system_start), int(min(system_end, cur_len))],
- "image": [int(image_start_union), int(min(image_end_union, cur_len))],
- "images": [[int(s), int(min(e, cur_len))] for (s, e) in seg_ranges_images],
- "instruct": [int(instruct_start), int(instruct_end)],
- "answer": [int(answer_start), int(cur_len)],
- "last_instruct": [int(last_instruct_start_embed), int(instruct_end)],
- })
-
- new_inputs_embeds.append(cur_new_embeds)
- new_labels.append(cur_new_labels)
-
- # Optional truncation to tokenizer max length (like the second function)
- tokenizer_model_max_length = getattr(self.llm.config, "tokenizer_model_max_length", None)
- if tokenizer_model_max_length is not None:
- truncated_inputs = []
- truncated_labels = []
- for i, (emb, lab, segs) in enumerate(zip(new_inputs_embeds, new_labels, per_sample_segments)):
- if emb.shape[0] > tokenizer_model_max_length:
- emb = emb[:tokenizer_model_max_length]
- lab = lab[:tokenizer_model_max_length]
- # clamp segment ranges
- for k, v in segs.items():
- if k == "images":
- for r in v:
- r[0] = min(r[0], tokenizer_model_max_length)
- r[1] = min(r[1], tokenizer_model_max_length)
- continue
- v[0] = min(v[0], tokenizer_model_max_length)
- v[1] = min(v[1], tokenizer_model_max_length)
- truncated_inputs.append(emb)
- truncated_labels.append(lab)
- new_inputs_embeds = truncated_inputs
- new_labels = truncated_labels
-
- # Batch combine with left/right padding support
- max_len = max(x.shape[0] for x in new_inputs_embeds) if new_inputs_embeds else 0
- batch_size = len(new_inputs_embeds)
-
- new_inputs_embeds_padded = []
- new_labels_padded = torch.full((batch_size, max_len),
- IGNORE_INDEX,
- dtype=new_labels[0].dtype if new_labels else torch.long,
- device=new_labels[0].device if new_labels else input_ids.device)
- attn = torch.zeros((batch_size, max_len),
- dtype=attention_mask.dtype, device=attention_mask.device)
- pos = torch.zeros((batch_size, max_len),
- dtype=position_ids.dtype, device=position_ids.device)
-
- pad_side = getattr(self.llm.config, "tokenizer_padding_side", "right")
- # adjust (pad) indices after we know per-sample cur_len
- input_embeds_indices = []
-
- for i, (emb, lab, segs) in enumerate(zip(new_inputs_embeds, new_labels, per_sample_segments)):
- cur_len = emb.shape[0]
- if pad_side == "left":
- pad_len = max_len - cur_len
- padded = torch.cat([torch.zeros((pad_len, emb.shape[1]), dtype=emb.dtype, device=emb.device), emb], dim=0)
- new_inputs_embeds_padded.append(padded)
- if cur_len > 0:
- new_labels_padded[i, -cur_len:] = lab
- attn[i, -cur_len:] = True
- pos[i, -cur_len:] = torch.arange(0, cur_len, dtype=pos.dtype, device=pos.device)
- # shift segment ranges by pad_len
- adj_segs = {}
- for k, v in segs.items():
- if k == "images":
- adj_segs[k] = [[r[0]+pad_len, r[1]+pad_len] for r in v]
- else:
- adj_segs[k] = [v[0]+pad_len, v[1]+pad_len]
- input_embeds_indices.append(adj_segs)
- else:
- # right pad
- padded = torch.cat([emb, torch.zeros((max_len - cur_len, emb.shape[1]), dtype=emb.dtype, device=emb.device)], dim=0)
- new_inputs_embeds_padded.append(padded)
- if cur_len > 0:
- new_labels_padded[i, :cur_len] = lab
- attn[i, :cur_len] = True
- pos[i, :cur_len] = torch.arange(0, cur_len, dtype=pos.dtype, device=pos.device)
- # no shift needed
- input_embeds_indices.append(segs)
-
- new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0) if new_inputs_embeds_padded else \
- torch.zeros((0, 0, embed_tokens.embedding_dim), device=input_ids.device, dtype=embed_tokens.weight.dtype)
-
- # finalize (respect original None-ness)
- if _labels is None:
- final_labels = None
- else:
- final_labels = new_labels_padded
-
- if _attention_mask is None:
- final_attn = None
- else:
- final_attn = attn.to(dtype=_attention_mask.dtype)
-
- if _position_ids is None:
- final_pos = None
- else:
- final_pos = pos
-
- out = {
- 'input_ids': None,
- 'position_ids': final_pos,
- 'attention_mask': final_attn,
- 'past_key_values': past_key_values,
- 'inputs_embeds': new_inputs_embeds,
- 'labels': final_labels
- }
-
- return (out, (input_embeds_indices,)) if return_indices else out
-
- # ------------------------------
- # Forward / Predict / Loss
- # ------------------------------
- def _encode_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
- """pixel_values are *precomputed* features [B, N, C_feat] -> Projected to LLM hidden."""
- feat_to_proj = pixel_values.to(self.llm.dtype)
- # print_log("dtype of pixel_values: {}".format(feat_to_proj.dtype), "current")
- # print_log("dtype of long_net_encoder: {}".format(self.LongNet_encoder.dtype), "current")
- if not isinstance(self.LongNet_encoder, nn.Identity):
- out = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))
- if isinstance(out, dict) and "encoder_out" in out:
- eo = out["encoder_out"]
- feat_to_proj = eo.permute(1, 0, 2) if eo.dim() == 3 else eo
- else:
- feat_to_proj = out
-
- if self.divprune_ratio < 1.0:
- # This assumes feat_to_proj has a shape of [batch_size, num_tokens, feature_dim].
- # It iterates through each item in the batch, prunes it, and stacks the results.
- # This works if num_tokens is the same for all items in the batch, which is a
- # standard assumption for batched processing.
- pruned_batch_features = []
- for visual_tokens in feat_to_proj: # Iterate over the batch dimension
- img_feature_len = visual_tokens.shape[0]
- selected_indices, _ = self.DivPrune(
- visual_tokens,
- img_feature_len,
- threshold_ratio=self.divprune_ratio
- )
- selected_indices = torch.sort(selected_indices).values
- pruned_features = visual_tokens[selected_indices]
- pruned_batch_features.append(pruned_features)
-
- # Stack the list of pruned tensors back into a single batch tensor
- feat_to_proj = torch.stack(pruned_batch_features, dim=0)
-
- # print("Shape of features before projector: {}".format(feat_to_proj.shape), "current")
- return self.projector(feat_to_proj)
-
- def _forward(self, data, data_samples=None):
- return self.llm(**data)
-
- def predict(self, data, data_samples=None):
- pixel_values = None
- if "pixel_values" in data and data["pixel_values"] is not None:
- pixel_values = self._encode_features(data["pixel_values"])
-
- packed, (input_embeds_indices,) = self._prepare_inputs_labels_for_multimodal_dynamic(
- input_ids=data["input_ids"],
- labels=data.get("labels"),
- attention_mask=data.get("attention_mask"),
- position_ids=data.get("position_ids"),
- past_key_values=data.get("past_key_values"),
- pixel_values=pixel_values,
- return_indices=True, # <<< important
- )
-
- attn = None if getattr(self.llm.config, "_attn_implementation", "") == "flash_attention_2" else packed["attention_mask"]
-
- outputs = self.llm(
- inputs_embeds=packed["inputs_embeds"],
- attention_mask=attn,
- position_ids=packed["position_ids"],
- labels=packed["labels"],
- input_embeds_indices=input_embeds_indices, # <<< use the unpacked indices
- use_cache=True,
- )
- return [{"logits": l} for l in outputs.logits]
-
- def compute_loss(self, data, data_samples=None):
- pixel_values = None
- if "pixel_values" in data and data["pixel_values"] is not None:
- pixel_values = self._encode_features(data["pixel_values"])
-
- # ask for indices explicitly
- packed, (input_embeds_indices,) = self._prepare_inputs_labels_for_multimodal_dynamic(
- input_ids=data["input_ids"],
- labels=data.get("labels"),
- attention_mask=data.get("attention_mask"),
- position_ids=data.get("position_ids"),
- past_key_values=data.get("past_key_values"),
- pixel_values=pixel_values,
- return_indices=True, # <<< important
- )
-
- # FA2-safe: do not pass a padding mask
- attn = None if getattr(self.llm.config, "_attn_implementation", "") == "flash_attention_2" else packed["attention_mask"]
-
- # print_log("Shape of inputs_embeds: {}".format(packed["inputs_embeds"].shape), "current")
- # print_log('input_embeds_indices shape: {}'.format(input_embeds_indices), "current")
- outputs = self.llm(
- inputs_embeds=packed["inputs_embeds"],
- attention_mask=attn,
- position_ids=packed["position_ids"],
- labels=packed["labels"],
- input_embeds_indices=input_embeds_indices, # <<< use the unpacked indices
- use_cache=False,
- )
-
- loss = outputs.loss
-
- return {"loss": loss}
-
-
- # mmengine requires a concrete forward; dispatch to compute_loss/predict
- def forward(self, data, data_samples=None, mode: str = "loss"):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- if mode == "loss":
- return self.compute_loss(data, data_samples)
- elif mode == "predict":
- return self.predict(data, data_samples)
- elif mode == "tensor":
- return self._forward(data, data_samples)
- else:
- raise KeyError(f"Invalid mode: {mode}")
- # ---------- helpers ----------
- def _is_peft(self) -> bool:
- try:
- from peft import PeftModel # type: ignore
- except Exception:
- return False
- return isinstance(self.llm, PeftModel)
-
- def _gather_predictor_keys_from_llm(self) -> Iterable[str]:
- """
- Collect parameter names (relative to self.llm) that belong to the dynamic predictors
- we added inside the decoder: image_score_predictor / output_text_score_predictor / instruct_score_predictor.
- """
- predictor_roots = (
- "model.image_score_predictor",
- "model.output_text_score_predictor",
- "model.instruct_score_predictor",
- # Some PEFT wrappers add an extra 'model' hop; be permissive:
- "base_model.model.model.image_score_predictor",
- "base_model.model.model.output_text_score_predictor",
- "base_model.model.model.instruct_score_predictor",
- )
- keys = []
- for name, _ in self.llm.named_parameters():
- if any(name.startswith(root) for root in predictor_roots):
- keys.append(name)
- return keys
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
- # def _parse_lora_config(self, lora_config):
- # """Accept dict/ConfigDict or PeftConfig; always return a PeftConfig."""
- # if isinstance(lora_config, (Config, ConfigDict, dict)):
- # cfg = dict(lora_config)
-
- # # 1) sanitize/ensure granular targets (avoid broad names like "layers", "mlp", "self_attn")
- # targets = cfg.get("target_modules", None)
- # if not targets:
- # targets = list(self._QWEN2_LORA_TARGETS)
- # else:
- # clean = []
- # for t in targets:
- # for allowed in self._QWEN2_LORA_TARGETS:
- # if allowed == t or allowed in t:
- # clean.append(allowed)
- # targets = sorted(set(clean)) or list(self._QWEN2_LORA_TARGETS)
- # cfg["target_modules"] = targets
-
- # # 2) default knobs that work on Qwen2
- # cfg.setdefault("task_type", TaskType.CAUSAL_LM)
- # cfg.setdefault("bias", "none")
-
- # return BUILDER.build(cfg, PeftConfig, "peft")
-
- # # already a PeftConfig
- # return lora_config
- # ---------- save ----------
- def state_dict(self, destination=None, prefix:str = '', keep_vars:bool = False) -> "OrderedDict[str, torch.Tensor]":
- """
- Save a compact checkpoint:
- - LoRA adapter weights from self.llm (if PEFT)
- - Dynamic predictor modules inside the LLM
- - Projector
- - LongNet encoder
- This keeps resumes fast & small under QLoRA.
- """
- sd = OrderedDict()
-
- # 1) Projector & LongNet (live on the wrapper)
- for k, v in self.projector.state_dict(keep_vars=keep_vars).items():
- sd[prefix + f"projector.{k}"] = v
- if hasattr(self, "LongNet_encoder") and isinstance(self.LongNet_encoder, nn.Module):
- for k, v in self.LongNet_encoder.state_dict(keep_vars=keep_vars).items():
- sd[prefix + f"LongNet_encoder.{k}"] = v
-
- # 2) LLM bits
- if self._is_peft():
- # LoRA adapters only
- from peft import get_peft_model_state_dict # type: ignore
- lora_sd = get_peft_model_state_dict(self.llm)
- for k, v in lora_sd.items():
- sd[prefix + f"llm.{k}"] = v
-
- # plus dynamic predictors that sit inside the decoder (not part of LoRA adapters)
- llm_full_sd = self.llm.state_dict(keep_vars=keep_vars)
- for k in self._gather_predictor_keys_from_llm():
- sd[prefix + f"llm.{k}"] = llm_full_sd[k]
- else:
- # No PEFT: save ONLY the predictor heads inside LLM to keep the checkpoint small;
- # (base LLM weights are assumed frozen or recoverable from pretrained path)
- llm_full_sd = self.llm.state_dict(keep_vars=keep_vars)
- for k in self._gather_predictor_keys_from_llm():
- sd[prefix + f"llm.{k}"] = llm_full_sd[k]
-
- return sd
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, sparse_config, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
-
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path)
-
- llm_cfg.sparse_config = (sparse_config or SparseConfig().__dict__)
- # cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- # traverse_dict(cfg_or_mod)
- print_log("Building DynamicLLaVAQwen25 from config dict", "current")
-
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
-
- # ---------- load ----------
- def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = False):
- """
- Load from a compact checkpoint produced by state_dict() above.
- We distribute keys into the right submodules and ignore everything else.
- """
- missing_keys_all = []
- unexpected_keys_all = []
- error_msgs_all = []
-
- # Split by prefix
- proj_sd = OrderedDict()
- ln_sd = OrderedDict()
- llm_sd = OrderedDict()
- for k, v in state_dict.items():
- if k.startswith("projector."):
- proj_sd[k[len("projector."):]] = v
- elif k.startswith("LongNet_encoder."):
- ln_sd[k[len("LongNet_encoder."):]] = v
- elif k.startswith("llm."):
- llm_sd[k[len("llm."):]] = v
- else:
- # Allow extra keys; record as unexpected
- unexpected_keys_all.append(k)
-
- # 1) Load projector
- if proj_sd:
- mk = self.projector.load_state_dict(proj_sd, strict=False)
- missing_keys_all.extend([f"projector.{k}" for k in mk.missing_keys])
- unexpected_keys_all.extend([f"projector.{k}" for k in mk.unexpected_keys])
-
- # 2) Load LongNet
- if ln_sd and hasattr(self, "LongNet_encoder") and isinstance(self.LongNet_encoder, nn.Module):
- mk = self.LongNet_encoder.load_state_dict(ln_sd, strict=False)
- missing_keys_all.extend([f"LongNet_encoder.{k}" for k in mk.missing_keys])
- unexpected_keys_all.extend([f"LongNet_encoder.{k}" for k in mk.unexpected_keys])
-
- # 3) Load LLM bits (LoRA adapters + predictors)
- if llm_sd:
- mk = self.llm.load_state_dict(llm_sd, strict=False)
- missing_keys_all.extend([f"llm.{k}" for k in mk.missing_keys])
- unexpected_keys_all.extend([f"llm.{k}" for k in mk.unexpected_keys])
-
- # Respect 'strict' like torch.nn.Module.load_state_dict
- if strict and (missing_keys_all or unexpected_keys_all):
- msgs = []
- if missing_keys_all:
- msgs.append(f"Missing keys: {missing_keys_all}")
- if unexpected_keys_all:
- msgs.append(f"Unexpected keys: {unexpected_keys_all}")
- raise RuntimeError("Error(s) in loading state_dict for DynamicLLaVAQwen25:\n\t" + "\n\t".join(msgs))
-
- # Return an object mirroring torch's IncompatibleKeys for convenience
- class _Compat:
- def __init__(self, missing, unexpected):
- self.missing_keys = missing
- self.unexpected_keys = unexpected
- return _Compat(missing_keys_all, unexpected_keys_all)
\ No newline at end of file
diff --git a/code/xtuner/model/dynamic_llava/dynamic_qwen.py b/code/xtuner/model/dynamic_llava/dynamic_qwen.py
deleted file mode 100644
index 63591e94c9d5eb062515abea32ad2ad369fe7d2b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/dynamic_llava/dynamic_qwen.py
+++ /dev/null
@@ -1,2089 +0,0 @@
-import math
-from typing import List, Optional, Tuple, Union
-import inspect
-import torch
-import warnings
-from dataclasses import dataclass, asdict
-# import torch.utils.checkpoint
-from torch import nn
-import torch.nn.functional as F
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from xtuner.registry import BUILDER
-from mmengine import print_log
-from transformers.activations import ACT2FN
-from .cache_utils import DynamicCachePlus
-from transformers.cache_utils import Cache
-from transformers.modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast,
- TokenClassifierOutput,
-)
-from transformers.modeling_attn_mask_utils import (
- AttentionMaskConverter,
- _prepare_4d_causal_attention_mask,
- _prepare_4d_causal_attention_mask_for_sdpa)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- is_flash_attn_greater_or_equal_2_10,
- logging,
- replace_return_docstrings,
-)
-
-try:
- from transformers.generation import GenerationMixin
-except Exception:
- from transformers.generation.utils import GenerationMixin
-
-from .custom_transformer_layer import SelfTransformerEncoderBlock
-
-# from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
-
-if is_flash_attn_2_available():
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
-
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
-
-
-logger = logging.get_logger(__name__)
-
-
-_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
-_CONFIG_FOR_DOC = "Qwen2Config"
-
-QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
- "Qwen/Qwen2-7B-beta",
- # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
-]
-def _finite_stats(name, t):
- if not torch.isfinite(t).all():
- bad = (~torch.isfinite(t)).sum().item()
- m = t.abs().max().item()
- raise RuntimeError(f"{name}: non-finite={bad}, max_abs={m}")
- # Also check magnitude to catch impending overflow
- if t.dtype in (torch.float16, torch.bfloat16) and t.abs().max() > 6e3:
- print(f"[warn] {name} max_abs ~ {t.abs().max().item():.1f}")
-
-@dataclass
-class SparseConfig:
- use_vision_predictor: bool = True
- use_text_predictor: bool = True
- use_output_text_predictor: bool = True
- use_instruct_predictor: bool = False
- sparse_layer: int = 16
- d_model: int = 512
- nhead: int = 8
- dim_feedforward: int = 2048
- num_layers: int = 2
- vision_keep_rate: float = 0.5
- output_text_keep_rate: float = 0.8
- instruct_keep_rate: float = 0.9
- mask_loss_weight: float = 0.1
- output_text_len_for_training: int = 16
- instruct_len_for_training: int = 16
-
- def to_dict(self) -> dict:
- return asdict(self)
-
-def _coerce_sparse_config(sparse_cfg, hidden_size: int) -> SparseConfig:
- """
- Accepts None, dict, or SparseConfig and returns a fully-populated SparseConfig.
- Provides dimension defaults based on the model's hidden_size.
- """
- base = SparseConfig(
- d_model=hidden_size,
- nhead=max(1, hidden_size // 64),
- dim_feedforward=hidden_size * 4,
- )
-
- if sparse_cfg is None:
- return base
-
- if isinstance(sparse_cfg, SparseConfig):
- # ensure sensible dimension defaults if user kept dataclass defaults
- sc = SparseConfig(**vars(sparse_cfg))
- if sc.d_model == 512:
- sc.d_model = hidden_size
- if sc.nhead == 8:
- sc.nhead = max(1, hidden_size // 64)
- if sc.dim_feedforward == 2048:
- sc.dim_feedforward = hidden_size * 4
- return sc
-
- if isinstance(sparse_cfg, dict):
- for k, v in sparse_cfg.items():
- if hasattr(base, k):
- setattr(base, k, v)
- return base
-
- raise TypeError("config.sparse_config must be None, dict, or SparseConfig.")
-
-
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
-class Qwen2RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- Qwen2RMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
-class Qwen2RotaryEmbedding(nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
- )
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
- freqs = torch.outer(t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`):
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
- used to pass offsetted position ids when working with a KV-cache.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
-class Qwen2MLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-class Qwen2Attention(nn.Module):
- """
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
- and "Generating Long Sequences with Sparse Transformers".
- """
-
- def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
-
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
- self.attention_dropout = config.attention_dropout
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
- self.rotary_emb = Qwen2RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.rope_theta,
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
- )
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
- "with a layer index."
- )
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
-
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class Qwen2FlashAttention2(Qwen2Attention):
- """
- Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
- as the weights of the module stays untouched. The only required change would be on the forward pass
- where it needs to correctly call the public API of flash attention and deal with padding tokens
- in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
- config.max_window_layers layers.
- """
-
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ):
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
- )
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop("padding_mask")
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
- "with a layer index."
- )
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
- # Because the input can be padded, the absolute sequence length depends on the max position id.
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, "sliding_window", None) is not None
- and kv_seq_len > self.config.sliding_window
- and self.config.use_sliding_window
- )
-
- if not _flash_supports_window_size:
- logger.warning_once(
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
- " make sure to upgrade flash-attn library."
- )
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (
- getattr(self.config, "sliding_window", None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents
- ):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
- f" {past_key.shape}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in float16 just to be sure everything works as expected.
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=dropout_rate,
- use_sliding_windows=use_sliding_windows,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
- def _flash_attention_forward(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length,
- dropout=0.0,
- softmax_scale=None,
- use_sliding_windows=False,
- ):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
-
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`int`, *optional*):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- use_sliding_windows (`bool`, *optional*):
- Whether to activate sliding window attention.
- """
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
- causal = self.is_causal and query_length != 1
-
- # Decide whether to use SWA or not by layer index.
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
- use_sliding_windows = False
-
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, attention_mask, query_length
- )
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- if not use_sliding_windows:
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
- else:
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=(self.config.sliding_window, self.config.sliding_window),
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
- else:
- if not use_sliding_windows:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
- else:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=(self.config.sliding_window, self.config.sliding_window),
- )
-
- return attn_output
-
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-
- # On the first iteration we need to properly re-create the padding mask
- # by slicing it on the proper place
- if kv_seq_len != attention_mask.shape[-1]:
- attention_mask_num_tokens = attention_mask.shape[-1]
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
-
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
- if query_length == kv_seq_len:
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
- )
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-def softmax_with_policy(attn, policy, eps=1e-6):
- B, N, _ = policy.size()
- B, H, N, N = attn.size()
- attn_policy = policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1)
- eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(
- 1, 1, N, N
- )
- attn_policy = attn_policy + (1.0 - attn_policy) * eye
- max_att = torch.max(attn, dim=-1, keepdim=True)[0]
- attn = attn - max_att
- # attn = attn.exp_() * attn_policy
- # return attn / attn.sum(dim=-1, keepdim=True)
-
- # for stable training
- attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
- attn = (attn + eps / N) / (attn.sum(dim=-1, keepdim=True) + eps)
- return attn.type_as(max_att)
-
-
-# Efficient implementation equivalent to the following:
-def scaled_dot_product_attention_with_policy(
- query,
- key,
- value,
- attn_mask=None,
- dropout_p=0.0,
- is_causal=False,
- scale=None,
- policy=None,
-) -> torch.Tensor:
- B = query.size(0)
- L, S = query.size(-2), key.size(-2)
- scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
- if attn_mask is not None:
- attn_bias = torch.zeros_like(attn_mask, dtype=query.dtype, device=query.device)
- else:
- attn_bias = torch.zeros(B, 1, L, S, dtype=query.dtype, device=query.device)
- if is_causal:
- assert attn_mask is None
- temp_mask = torch.ones(B, 1, L, S, dtype=torch.bool, device=query.device).tril(
- diagonal=0
- )
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
- attn_bias.to(query.dtype)
-
- if attn_mask is not None:
- if attn_mask.dtype == torch.bool:
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
- else:
- attn_bias += attn_mask
- attn_weight = query @ key.transpose(-2, -1) * scale_factor
- attn_weight += attn_bias
- if policy is not None:
- attn_weight = softmax_with_policy(attn_weight, policy=policy)
- else:
- attn_weight = torch.softmax(attn_weight, dim=-1)
- attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
- return attn_weight @ value
-
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
-class DynamicQwen2SdpaAttention(Qwen2Attention):
- """
- Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
- `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
- SDPA API.
- """
-
- # Adapted from LlamaAttention.forward
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- policy: Optional[torch.Tensor] = None,
- kv_seq_len_for_position: Optional[int] = None,
- sparse_layer: Optional[int] = None,
- text_decision: Optional[List[bool]] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ).transpose(1, 2)
- key_states = key_states.view(
- bsz, q_len, self.num_key_value_heads, self.head_dim
- ).transpose(1, 2)
- value_states = value_states.view(
- bsz, q_len, self.num_key_value_heads, self.head_dim
- ).transpose(1, 2)
-
- # ----------------------------------------------------------#
- kv_seq_len = key_states.shape[-2]
- if kv_seq_len_for_position is None:
- kv_seq_len_for_position = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- if sparse_layer is not None and self.layer_idx >= sparse_layer:
- kv_seq_len_for_position += (
- past_key_value.get_usable_length(
- kv_seq_len_for_position, sparse_layer - 1
- )
- - kv_seq_len_for_position
- )
- else:
- kv_seq_len_for_position += past_key_value.get_usable_length(
- kv_seq_len_for_position, self.layer_idx
- )
- # ----------------------------------------------------------#
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len_for_position)
-
- # print(f"cos: {cos.shape}, sin: {sin.shape}, position_ids: {position_ids}")
-
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids
- )
-
- if past_key_value is not None:
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
-
- # ----------------------------------------------------------#
- # key_states, value_states = past_key_value.update(
- # key_states,
- # value_states,
- # self.layer_idx,
- # cache_kwargs,
- # text_decision,
- # )
-
- if text_decision is not None:
- # Fix one bug
- temp_key_states, temp_value_states = key_states, value_states
- key_states, value_states = past_key_value.get_cache(
- key_states,
- value_states,
- self.layer_idx,
- cache_kwargs,
- )
- _, _ = past_key_value.update(
- temp_key_states,
- temp_value_states,
- self.layer_idx,
- cache_kwargs,
- text_decision,
- )
- else:
- key_states, value_states = past_key_value.update(
- key_states,
- value_states,
- self.layer_idx,
- cache_kwargs,
- )
- # ----------------------------------------------------------#
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and attention_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- if policy is not None:
- attn_output = scaled_dot_product_attention_with_policy(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
- policy=policy,
- )
- else:
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-QWEN2_ATTENTION_CLASSES = {
- "eager": Qwen2Attention,
- "flash_attention_2": Qwen2FlashAttention2,
- "sdpa": DynamicQwen2SdpaAttention,
-}
-
-
-class DynamicQwen2DecoderLayer(nn.Module):
- def __init__(self, config: Qwen2Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
-
- if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
- logger.warning_once(
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
- "unexpected results may be encountered."
- )
- self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-
- self.mlp = Qwen2MLP(config)
- self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- policy: Optional[torch.Tensor] = None,
- kv_seq_len_for_position: Optional[int] = None,
- sparse_layer: Optional[int] = None,
- text_decision: Optional[List[bool]] = None,
- **kwargs,
- ) -> Tuple[
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
- ]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
- )
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
- # print(f"hidden_states: {hidden_states.shape}, attention_mask: {attention_mask.shape if attention_mask is not None else None}, position_ids: {position_ids.shape if position_ids is not None else None}")
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- policy=policy,
- kv_seq_len_for_position=kv_seq_len_for_position,
- sparse_layer=sparse_layer,
- text_decision=text_decision,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-class VisionPredictor(nn.Module):
- def __init__(
- self,
- input_dim=4096,
- d_model=512,
- nhead=8,
- dim_feedforward=2048,
- num_layers=2,
- ):
- super().__init__()
- self.input_dim = input_dim
- self.d_model = d_model
- self.nhead = nhead
- self.dim_feedforward = dim_feedforward
- self.num_layers = num_layers
-
- self.down_mlp = nn.Sequential(
- nn.LayerNorm(self.input_dim),
- nn.Linear(self.input_dim, self.d_model),
- nn.GELU(),
- )
- self.transformer = nn.Sequential(
- *[
- SelfTransformerEncoderBlock(
- dim=self.d_model,
- num_heads=self.nhead,
- mlp_ratio=self.dim_feedforward / self.d_model,
- )
- for _ in range(self.num_layers)
- ]
- )
- self.output_mlp = nn.Sequential(
- nn.Linear(self.d_model, self.d_model // 2),
- nn.GELU(),
- nn.Linear(self.d_model // 2, self.d_model // 4),
- nn.GELU(),
- nn.Linear(self.d_model // 4, 2),
- # nn.LogSoftmax(dim=-1),
- )
- def forward(self, x, image_policy) -> torch.FloatTensor:
- # # print('x', x.mean(), x.shape, x.dtype)
-
- # # --- DEBUGGING NANs ---
- # # 1. Check if input is already corrupt
- # if torch.isnan(x).any():
- # print("ERROR: NaN detected in input 'x' before down_mlp.")
- # # This is critical. If this fires, the problem is in your data pipeline.
-
- # # 2. Check the scale/magnitude of the input values
- # # If max is very large (e.g., > 1e4), it will likely cause overflow in LayerNorm.
- # if x.max() > 1e4:
- # print(f"WARNING: Input 'x' has a very large max value: {x.max().item()}. This can cause NaNs in LayerNorm.")
-
- # --- Original Code ---
- new_image_x = self.down_mlp(x)
-
- # print('new_image_x', new_image_x.mean(), new_image_x.shape, new_image_x.dtype)
- # print('image_policy', image_policy, image_policy.shape, image_policy.dtype)
-
- before_transformer = new_image_x * image_policy
- # print('before_transformer', before_transformer.mean(), before_transformer.shape, before_transformer.dtype)
- new_x = self.transformer(before_transformer)
- # print('image_policy', image_policy.mean(), image_policy.mean(dim=1).dtype)
- # print('new_x', new_x.mean(), new_x.shape, new_x.dtype)
- B, N, C = new_x.size()
- local_x = new_x[:, :, : C // 2]
-
- # Calculate the sum of the policy weights
- policy_sum = torch.sum(image_policy, dim=1, keepdim=True)
-
- # Calculate the weighted sum for the global features
- global_x_num = (new_x[:, :, C // 2 :] * image_policy).sum(dim=1, keepdim=True)
- # print('global_x_num', global_x_num.mean(), global_x_num.shape, global_x_num.dtype)
- # **THE FIX**: Clamp the denominator to a small positive value to prevent division by zero
- global_x = global_x_num / torch.clamp(policy_sum, min=1e-6)
- # print('global_x', global_x.mean(), global_x.shape, global_x.dtype)
- new_x = torch.cat([local_x, global_x.expand(B, N, C // 2)], dim=-1)
- # print('new_x', new_x.mean(), new_x.shape, new_x.dtype)
- predict = self.output_mlp(new_x)
- return predict
-
- # def forward(self, x, image_policy) -> torch.FloatTensor:
- # predict = []
- # new_image_x = self.down_mlp(x)
- # new_x = self.transformer(new_image_x * image_policy)
-
- # B, N, C = new_x.size()
- # local_x = new_x[:, :, : C // 2]
- # print('image_policy', image_policy.mean(), image_policy.mean(dim=1).dtype)
- # print('local_x', local_x.mean(), local_x.shape, local_x.dtype)
- # global_x = (new_x[:, :, C // 2 :] * image_policy).sum(
- # dim=1, keepdim=True
- # ) / torch.sum(image_policy, dim=1, keepdim=True)
- # print('global_x', global_x.mean(), global_x.shape, global_x.dtype)
- # new_x = torch.cat([local_x, global_x.expand(B, N, C // 2)], dim=-1)
- # print('new_x', new_x.mean(), new_x.shape, new_x.dtype)
- # predict = self.output_mlp(new_x)
- # print('predict', predict.mean(), predict.shape, predict.dtype)
- # return predict
-
-
-class TextPredictor(nn.Module):
- def __init__(
- self,
- input_dim=4096,
- d_model=512,
- nhead=8,
- dim_feedforward=2048,
- num_layers=2,
- ):
- super().__init__()
- self.input_dim = input_dim
- self.d_model = d_model
- self.output_mlp = nn.Sequential(
- nn.LayerNorm(self.input_dim),
- nn.Linear(self.input_dim, self.d_model),
- nn.GELU(),
- nn.Linear(self.d_model, self.d_model // 2),
- nn.GELU(),
- nn.Linear(self.d_model // 2, self.d_model // 4),
- nn.GELU(),
- nn.Linear(self.d_model // 4, 2),
- )
-
- def forward(self, x) -> torch.FloatTensor:
- predict = self.output_mlp(x)
- return predict
-
-
-def weight_merging(select_token, unselect_token):
- pass
-
-
-def ste_argmax(logits: torch.Tensor, dim: int = -1):
- y_soft = logits
- index = y_soft.max(dim, keepdim=True)[1]
- y_hard = torch.zeros_like(
- logits, memory_format=torch.legacy_contiguous_format
- ).scatter_(dim, index, 1.0)
- ret = y_hard - y_soft.detach() + y_soft
- return ret
-
-
-def ste_topk(logits: torch.Tensor, k: int, dim: int = -1):
- y_soft = logits
- index = y_soft.topk(k=k, dim=dim, largest=True, sorted=False)[1]
- y_hard = torch.zeros_like(
- logits, memory_format=torch.legacy_contiguous_format
- ).scatter_(dim, index, 1.0)
- ret = y_hard - y_soft.detach() + y_soft
- return ret
-
-@dataclass
-class Qwen2DynamicModelOutputWithPast(BaseModelOutputWithPast):
- image_masks: Optional[List[torch.Tensor]] = None
- output_text_masks_batch_list: Optional[List[List[torch.Tensor]]] = None
- instruct_masks_batch_list: Optional[List[List[torch.Tensor]]] = None
- image_score_predictor_logits: Optional[List[torch.Tensor]] = None
-
-
-QWEN2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`Qwen2Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
- QWEN2_START_DOCSTRING,
-)
-@BUILDER.register_module()
-class DynamicQwen2PreTrainedModel(PreTrainedModel):
- config_class = Qwen2Config
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- # _no_split_modules = ["Qwen2DecoderLayer"]
- _no_split_modules = ["DynamicLlamaDecoderLayer"]
- _skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-QWEN2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance;
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
-
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-@add_start_docstrings(
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
- QWEN2_START_DOCSTRING,
-)
-class DynamicQwen2Model(DynamicQwen2PreTrainedModel):
- def __init__(self, config: "Qwen2Config"):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList(
- [DynamicQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self._attn_implementation = config._attn_implementation
- self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
-
- # ---------------- NEW: dataclass-based sparse config (single source of truth) ----------------
- # print('config.sparse_config', getattr(config, "sparse_config", None))
- self.sparse_config: SparseConfig = _coerce_sparse_config(
- getattr(config, "sparse_config", None), config.hidden_size
- )
- # --------------------------------------------------------------------------------------------
-
- # Build sparsity predictors from dataclass
- if self.sparse_config.use_vision_predictor:
- self.image_score_predictor = VisionPredictor(
- input_dim=config.hidden_size,
- d_model=self.sparse_config.d_model,
- nhead=self.sparse_config.nhead,
- dim_feedforward=self.sparse_config.dim_feedforward,
- num_layers=self.sparse_config.num_layers,
- )
-
- if self.sparse_config.use_text_predictor:
- if self.sparse_config.use_output_text_predictor:
- self.output_text_score_predictor = TextPredictor(
- input_dim=config.hidden_size,
- d_model=self.sparse_config.d_model,
- nhead=self.sparse_config.nhead,
- dim_feedforward=self.sparse_config.dim_feedforward,
- num_layers=self.sparse_config.num_layers,
- )
- if self.sparse_config.use_instruct_predictor:
- self.instruct_score_predictor = TextPredictor(
- input_dim=config.hidden_size,
- d_model=self.sparse_config.d_model,
- nhead=self.sparse_config.nhead,
- dim_feedforward=self.sparse_config.dim_feedforward,
- num_layers=self.sparse_config.num_layers,
- )
-
- self.gumbel_tau = 1.0
- self.answer_indice = None
-
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- # where the spans live
- input_embeds_indices: Optional[List[dict]] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- # print('use_cache', use_cache, 'past_key_values', past_key_values)
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- use_cache = use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape[:2]
- elif inputs_embeds is not None:
- batch_size, seq_length = inputs_embeds.shape[:2]
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
- if self.gradient_checkpointing and self.training and use_cache:
- use_cache = False # gradient checkpointing incompatible with use_cache
-
- past_key_values_length = 0
- # print('use_cache', use_cache, 'past_key_values', past_key_values)
- if use_cache:
- # print("Using cache for generation")
- use_legacy_cache = not isinstance(past_key_values, Cache)
- # print("Using legacy cache:", use_legacy_cache)
- if use_legacy_cache:
- past_key_values = DynamicCachePlus.from_legacy_cache(past_key_values)
- past_key_values_length = past_key_values.get_usable_length(seq_length)
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
- ).unsqueeze(0)
- else:
- position_ids = position_ids.view(position_ids.shape[0], -1).long()
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- # NOTE: Flash attention generation requirement remains unchanged
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
- if is_padding_right:
- raise ValueError(
- "For Flash Attention batched generation use padding_side='left' before tokenizing the input."
- )
- # print('inputs_embeds', inputs_embeds.mean(), inputs_embeds.shape, inputs_embeds.dtype)
- hidden_states = inputs_embeds
-
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
-
- # dynamic sparsity runtime vars
- B = hidden_states.shape[0]
- init_n = hidden_states.shape[1]
- policy = None
- image_masks: List[torch.Tensor] = []
- output_text_masks_batch_list: List[List[torch.Tensor]] = []
- instruct_masks_batch_list: List[List[torch.Tensor]] = []
- image_score_predictor_logits: List[torch.Tensor] = []
- text_decision = None
-
- # print("input_embeds_indices", input_embeds_indices)
- # print("model hidden_states shape", hidden_states.shape)
-
- if (
- self.sparse_config.use_vision_predictor
- and input_embeds_indices is not None
- and hidden_states.shape[0] == len(input_embeds_indices)
- ):
- init_image_n = input_embeds_indices[0]["image"][1] - input_embeds_indices[0]["image"][0]
- image_prev_decision = torch.ones(B, init_image_n, 1, dtype=hidden_states.dtype, device=hidden_states.device)
-
- for i, decoder_layer in enumerate(self.layers):
- if use_cache:
- past_key_values_length = past_key_values.get_usable_length(seq_length, i)
-
- if self._attn_implementation == "flash_attention_2":
- layer_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self._attn_implementation == "sdpa" and not output_attentions:
- layer_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
- )
- else:
- layer_attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length,
- sliding_window=self.config.sliding_window,
- )
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- # print('sparse_config', self.sparse_config)
-
- # Vision sparsifier
- if (
- self.sparse_config.use_vision_predictor
- and i == self.sparse_config.sparse_layer
- and input_embeds_indices is not None
- and hidden_states.shape[0] == len(input_embeds_indices)
- and hidden_states.shape[1] > 1
- # and (not use_cache or not past_key_values_length)
- ):
-
- per_img = []
- for b in range(B):
- s, e = input_embeds_indices[b]["image"]
- per_img.append(hidden_states[b, s:e, :])
-
- max_len = max(x.shape[0] for x in per_img)
- padded = [
- torch.cat([x, torch.zeros((max_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)], dim=0)
- for x in per_img
- ]
- new_image_x = torch.stack(padded, dim=0)
- # print('new_image_x value', new_image_x.mean().item(), 'shape', new_image_x.shape)
- # _finite_stats("new_image_x", new_image_x)
-
- # print('dypte of new_image_x', new_image_x.dtype, 'device', new_image_x.device)
- # print('image_prev_decision value', image_prev_decision.mean().item(), 'shape', image_prev_decision.shape)
- # print('current layer', i, 'sparse layer', self.sparse_config.sparse_layer)
- # print('new image_x shape', new_image_x.shape, 'image_prev_decision shape', image_prev_decision.shape)
- logits = self.image_score_predictor(new_image_x, image_prev_decision).reshape(B, -1, 2)
- # # print('logits value', logits.mean().item(), 'shape', logits.shape)
- # if not torch.isfinite(logits).all():
- # print("NaN in logits, dumping predictor params...")
- # for n,p in self.image_score_predictor.named_parameters():
- # if p is None: continue
- # if not torch.isfinite(p).all():
- # print(f"param {n} has non-finite values")
- # if p.grad is not None and not torch.isfinite(p.grad).all():
- # print(f"grad {n} has non-finite values")
- image_score_predictor_logits.append(logits)
- pred_score = F.log_softmax(logits, dim=-1)
-
- if self.training:
- keep_hard = (
- F.gumbel_softmax(pred_score, tau=self.gumbel_tau, hard=True)[:, :, 0:1] * image_prev_decision
- )
- # print('keep_hard value', keep_hard.mean().item(), 'pred_score value', pred_score.mean().item())
- # print('keep_hard.shape', keep_hard.shape, 'pred_score.shape', pred_score.shape)
-
- image_masks.append(keep_hard.reshape(B, init_image_n))
-
- left_ones = torch.ones(
- B, input_embeds_indices[0]["image"][0], 1, dtype=hidden_states.dtype, device=hidden_states.device
- )
- right_ones = torch.ones(
- B, init_n - input_embeds_indices[0]["image"][1], 1, dtype=hidden_states.dtype, device=hidden_states.device
- )
- policy = torch.cat([left_ones, keep_hard, right_ones], dim=1)
- image_prev_decision = keep_hard
- else:
- keep_score = pred_score[:, :, 0]
- num_keep = int(init_image_n * self.sparse_config.vision_keep_rate)
- keep_index, _ = torch.sort(
- torch.argsort(keep_score, dim=1, descending=True)[:, :num_keep], dim=1, descending=False
- )
-
- s0, e0 = input_embeds_indices[0]["image"]
- image_h = hidden_states[:, s0:e0, :]
- left_h = hidden_states[:, :s0, :]
- right_h = hidden_states[:, e0:, :]
-
- kept_img = image_h.gather(1, keep_index[..., None].expand(image_h.shape[0], num_keep, image_h.shape[2]))
- hidden_states = torch.cat([left_h, kept_img, right_h], dim=1)
-
- image_prev_decision = image_prev_decision.gather(
- 1, keep_index[..., None].expand(image_prev_decision.shape[0], num_keep, image_prev_decision.shape[2])
- )
- policy = None # physically pruned
-
- keep_index_for_image_position = (keep_index + s0).to(dtype=position_ids.dtype, device=position_ids.device)
- keep_left_pos = torch.arange(0, s0, device=position_ids.device, dtype=position_ids.dtype).repeat(B, 1)
- keep_right_pos = torch.arange(e0, init_n, device=position_ids.device, dtype=position_ids.dtype).repeat(B, 1)
- position_ids = torch.cat([keep_left_pos, keep_index_for_image_position, keep_right_pos], dim=1)
-
- removed = init_image_n - num_keep
- for idx in input_embeds_indices:
- idx["image"][1] -= removed
- idx["instruct"][0] -= removed
- idx["instruct"][1] -= removed
- idx["last_instruct"][0] -= removed
- idx["last_instruct"][1] -= removed
- idx["answer"][0] -= removed
- idx["answer"][1] -= removed
-
- # Text sparsifier
- if self.sparse_config.use_text_predictor and i == self.sparse_config.sparse_layer:
- # Training path: write hard decisions into `policy`
- if (
- input_embeds_indices is not None
- and hidden_states.shape[0] == len(input_embeds_indices)
- and self.training
- ):
- # print('DynamicQwen2Model: Using text predictor in training')
- # output text (answer)
- if self.sparse_config.use_output_text_predictor:
- max_ans_len = max(idx["answer"][1] - idx["answer"][0] for idx in input_embeds_indices)
- prev_dec = torch.ones(B, max_ans_len, 1, dtype=hidden_states.dtype, device=hidden_states.device)
- for b in range(B):
- true_len = input_embeds_indices[b]["answer"][1] - input_embeds_indices[b]["answer"][0]
- if true_len < max_ans_len:
- prev_dec[b, true_len:, :] = 0
- per_ans = [
- hidden_states[b, input_embeds_indices[b]["answer"][0] : input_embeds_indices[b]["answer"][1], :]
- for b in range(B)
- ]
- pad_len = max(x.shape[0] for x in per_ans)
- stacked = torch.stack([
- torch.cat([x, torch.zeros((pad_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)], dim=0)
- for x in per_ans
- ], dim=0)
- logits = self.output_text_score_predictor(stacked).reshape(B, -1, 2)
- pred = F.log_softmax(logits, dim=-1)
- keep_hard = F.gumbel_softmax(pred, tau=self.gumbel_tau, hard=True)[:, :, 0:1] * prev_dec
-
- for b in range(B):
- length = input_embeds_indices[b]["answer"][1] - input_embeds_indices[b]["answer"][0]
- if length < self.sparse_config.output_text_len_for_training:
- keep_hard[b, :length, :] = 1.0
-
- output_text_masks_batch_list.append([
- keep_hard[b, : (input_embeds_indices[b]["answer"][1] - input_embeds_indices[b]["answer"][0]), :].reshape(
- input_embeds_indices[b]["answer"][1] - input_embeds_indices[b]["answer"][0]
- )
- for b in range(B)
- ])
-
- if policy is not None:
- for b in range(B):
- s, e = input_embeds_indices[b]["answer"]
- policy[b, s:e, :] = keep_hard[b, : (e - s), :]
-
- # instruction (last_instruct)
- if self.sparse_config.use_instruct_predictor:
- max_ins_len = max(idx["last_instruct"][1] - idx["last_instruct"][0] for idx in input_embeds_indices)
- prev_dec = torch.ones(B, max_ins_len, 1, dtype=hidden_states.dtype, device=hidden_states.device)
- for b in range(B):
- true_len = input_embeds_indices[b]["last_instruct"][1] - input_embeds_indices[b]["last_instruct"][0]
- if true_len < max_ins_len:
- prev_dec[b, true_len:, :] = 0
-
- per_ins = [
- hidden_states[b, input_embeds_indices[b]["last_instruct"][0] : input_embeds_indices[b]["last_instruct"][1], :]
- for b in range(B)
- ]
- pad_len = max(x.shape[0] for x in per_ins)
- stacked = torch.stack([
- torch.cat([x, torch.zeros((pad_len - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)], dim=0)
- for x in per_ins
- ], dim=0)
- logits = self.instruct_score_predictor(stacked).reshape(B, -1, 2)
- pred = F.log_softmax(logits, dim=-1)
- keep_hard = F.gumbel_softmax(pred, tau=self.gumbel_tau, hard=True)[:, :, 0:1] * prev_dec
-
- for b in range(B):
- length = input_embeds_indices[b]["last_instruct"][1] - input_embeds_indices[b]["last_instruct"][0]
- if length < self.sparse_config.instruct_len_for_training:
- keep_hard[b, :length, :] = 1.0
-
- instruct_masks_batch_list.append([
- keep_hard[b, : (input_embeds_indices[b]["last_instruct"][1] - input_embeds_indices[b]["last_instruct"][0]), :].reshape(
- input_embeds_indices[b]["last_instruct"][1] - input_embeds_indices[b]["last_instruct"][0]
- )
- for b in range(B)
- ])
-
- if policy is not None:
- for b in range(B):
- s, e = input_embeds_indices[b]["last_instruct"]
- policy[b, s:e, :] = keep_hard[b, : (e - s), :]
-
- # Prefill inference (no KV), first instruction pruning
- elif (
- not self.training
- and not past_key_values_length
- and input_embeds_indices is not None
- and hidden_states.shape[0] == len(input_embeds_indices)
- and self.sparse_config.use_instruct_predictor
- ):
- assert B == 1, "Using text predictor in prefill currently assumes batch size = 1"
- s, e = input_embeds_indices[0]["last_instruct"]
- instruct_h = hidden_states[:, s : e - 1, :]
- logits = self.instruct_score_predictor(instruct_h).reshape(B, -1, 2)
- keep_idx = torch.where(logits[0, :, 0] > logits[0, :, 1])[0].unsqueeze(0)
-
- left_h = hidden_states[:, :s, :]
- right_h = hidden_states[:, e - 1 :, :]
- kept = instruct_h.gather(1, keep_idx[..., None].expand(instruct_h.shape[0], -1, instruct_h.shape[2]))
- hidden_states = torch.cat([left_h, kept, right_h], dim=1)
-
- keep_pos = position_ids[:, s : e - 1].gather(1, keep_idx)
- left_pos = position_ids[:, :s]
- right_pos = position_ids[:, e - 1 :]
- position_ids = torch.cat([left_pos, keep_pos, right_pos], dim=1)
-
- removed = (e - 1 - s) - keep_idx.shape[1]
- for idx in input_embeds_indices:
- idx["instruct"][1] -= removed
- idx["last_instruct"][1] -= removed
- idx["answer"][0] -= removed
- idx["answer"][1] -= removed
-
- # Decoding with KV: per-token keep decision for outputs
- elif (
- not self.training
- and past_key_values_length
- and hidden_states.shape[1] == 1
- and self.sparse_config.use_output_text_predictor
- ):
- logits = self.output_text_score_predictor(hidden_states).reshape(B, -1, 2)
- text_decision = (logits[:, :, 0] > logits[:, :, 1])
-
- # Full-seq inference (no KV): prune answer span without cache
- elif (
- not self.training
- and input_embeds_indices is not None
- and hidden_states.shape[0] == len(input_embeds_indices)
- and self.sparse_config.use_output_text_predictor
- and not use_cache
- ):
- if self.answer_indice is None:
- self.answer_indice = input_embeds_indices[0]["instruct"][1]
- s_ans = self.answer_indice
- out_h = hidden_states[:, s_ans : -1, :]
- logits = self.output_text_score_predictor(out_h).reshape(B, -1, 2)
- decision = (logits[:, :, 0] > logits[:, :, 1])
- num_keep = decision.sum(dim=1).max()
- keep_score = logits[:, :, 0]
- keep_idx, _ = torch.sort(
- torch.argsort(keep_score, dim=1, descending=True)[:, :num_keep], dim=1, descending=False
- )
-
- left_h = hidden_states[:, :s_ans, :]
- right_h = hidden_states[:, -1:, :]
- kept = out_h.gather(1, keep_idx[..., None].expand(out_h.shape[0], num_keep, out_h.shape[2]))
- hidden_states = torch.cat([left_h, kept, right_h], dim=1)
- policy = None
-
- keep_pos = position_ids[:, s_ans : -1].gather(1, keep_idx)
- left_pos = position_ids[:, :s_ans]
- right_pos = position_ids[:, -1:]
- position_ids = torch.cat([left_pos, keep_pos, right_pos], dim=1)
-
- # Call the layer (pass policy/text_decision)
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=layer_attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- policy=policy,
- kv_seq_len_for_position=init_n,
- sparse_layer=self.sparse_config.sparse_layer,
- text_decision=text_decision,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = None
- if use_cache:
- next_cache = next_decoder_cache.to_legacy_cache() if not isinstance(past_key_values, Cache) else next_decoder_cache
-
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
- return Qwen2DynamicModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- image_masks=image_masks,
- output_text_masks_batch_list=output_text_masks_batch_list,
- instruct_masks_batch_list=instruct_masks_batch_list,
- image_score_predictor_logits=image_score_predictor_logits,
- )
-
-@BUILDER.register_module()
-class DynamicQwen2ForCausalLM(DynamicQwen2PreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
-
- def __init__(self, config, sparse_config: Optional[SparseConfig] = None):
- super().__init__(config)
- config.sparse_config = sparse_config if sparse_config is not None else SparseConfig()
- self.model = DynamicQwen2Model(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- self.post_init()
-
- def update_sparse_config(self, sparse_config: SparseConfig):
- """
- Update the sparse configuration of the model.
- """
- self.sparse_config = sparse_config
- self.model.sparse_config = sparse_config
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- input_embeds_indices: Optional[List[dict]] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
-
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
-
- >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # print("DynamicQwen2ForCausalLM forward() called with input_embeds_indices:", input_embeds_indices)
- # print('use_cache', use_cache, 'past_key_values', past_key_values)
- # # print('input_ids', input_ids, 'attention_mask', attention_mask, 'position_ids', position_ids)
- # print('inputs_embeds', inputs_embeds, 'labels', labels)
-
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- input_embeds_indices=input_embeds_indices,
- )
-
- hidden_states = outputs[0]
- logits = self.lm_head(hidden_states).float()
-
- loss = None
- if labels is not None:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1).to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- # ---------------------- NEW: use dataclass fields for sparsity losses ----------------------
- sc = self.model.sparse_config
-
- if outputs.image_masks is not None and len(outputs.image_masks):
- image_mask_loss = 0.0
- for mask in outputs.image_masks:
- batch_ratio = mask.mean(dim=1)
- # print(f"Image mask batch ratio: {batch_ratio}")
- image_mask_loss = image_mask_loss + ((sc.vision_keep_rate - batch_ratio) ** 2).mean()
- loss = loss + sc.mask_loss_weight * image_mask_loss
-
- if outputs.output_text_masks_batch_list is not None and len(outputs.output_text_masks_batch_list):
- output_text_mask_loss = 0.0
- for mask_batch_list in outputs.output_text_masks_batch_list:
- batch_ratio = torch.stack([mask.mean() for mask in mask_batch_list])
- target_batch_ratio = torch.tensor(
- [
- (sc.output_text_keep_rate if mask.shape[0] >= sc.output_text_len_for_training else mask.mean().item())
- for mask in mask_batch_list
- ],
- dtype=batch_ratio.dtype, device=batch_ratio.device
- )
- output_text_mask_loss = output_text_mask_loss + ((target_batch_ratio - batch_ratio) ** 2).mean()
- loss = loss + sc.mask_loss_weight * output_text_mask_loss
-
- if outputs.instruct_masks_batch_list is not None and len(outputs.instruct_masks_batch_list):
- instruct_mask_loss = 0.0
- for mask_batch_list in outputs.instruct_masks_batch_list:
- batch_ratio = torch.stack([mask.mean() for mask in mask_batch_list])
- target_batch_ratio = torch.tensor(
- [
- (sc.instruct_keep_rate if mask.shape[0] >= sc.instruct_len_for_training else mask.mean().item())
- for mask in mask_batch_list
- ],
- dtype=batch_ratio.dtype, device=batch_ratio.device
- )
- instruct_mask_loss = instruct_mask_loss + ((target_batch_ratio - batch_ratio) ** 2).mean()
- loss = loss + sc.mask_loss_weight * instruct_mask_loss
- # ------------------------------------------------------------------------------------------
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- # def prepare_inputs_for_generation(
- # self, input_ids,
- # past_key_values=None,
- # attention_mask=None,
- # inputs_embeds=None,
- # input_embeds_indices=None, **kwargs
- # ):
-
- # if past_key_values is not None:
- # if isinstance(past_key_values, Cache):
- # cache_length = past_key_values.get_seq_length()
- # past_length = past_key_values.seen_tokens
- # max_cache_length = past_key_values.get_max_length()
- # else:
- # cache_length = past_length = past_key_values[0][0].shape[2]
- # max_cache_length = None
-
- # if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
- # input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
- # elif past_length < input_ids.shape[1]:
- # input_ids = input_ids[:, past_length:]
-
- # if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
- # attention_mask = attention_mask[:, -max_cache_length:]
-
- # position_ids = kwargs.get("position_ids", None)
- # if attention_mask is not None and position_ids is None:
- # position_ids = attention_mask.long().cumsum(-1) - 1
- # position_ids.masked_fill_(attention_mask == 0, 1)
- # if past_key_values:
- # position_ids = position_ids[:, -input_ids.shape[1]:]
-
- # if inputs_embeds is not None and past_key_values is None:
- # model_inputs = {"inputs_embeds": inputs_embeds}
- # else:
- # print("Preparing inputs for generation with input_ids:", input_ids)
- # model_inputs = {"input_ids": input_ids}
-
-
- # # print('prepare_inputs_for_generation input_ids', input_ids, 'input_embeds', inputs_embeds, 'past_key_values', past_key_values)
-
- # # print('prepare_use_cache', kwargs.get("use_cache"))
-
- # model_inputs.update(
- # {
- # "position_ids": position_ids,
- # "past_key_values": past_key_values,
- # "use_cache": False,
- # "attention_mask": attention_mask,
- # "input_embeds_indices": input_embeds_indices,
- # }
- # )
- # for key in model_inputs.keys():
- # print(f"prepare_inputs_for_generation: {key} = {model_inputs[key]}")
- # return model_inputs
- def prepare_inputs_for_generation(
- self,
- input_ids=None,
- inputs_embeds=None,
- past_key_values=None,
- attention_mask=None,
- position_ids=None,
- input_embeds_indices=None,
- cache_position=None, # 4.48 passes this; we accept but DO NOT forward to .forward()
- use_cache=None,
- **kwargs,
- ):
- """
- Bridge between HF generate() and the model's forward:
-
- • Step 0 (no/empty cache OR HF provided empty input_ids while using inputs_embeds):
- - Keep and pass the packed `inputs_embeds` (your multimodal embeddings).
- - Synthesize an attention mask and position_ids if missing.
-
- • Steps 1+ (cache is populated):
- - Pass only the last token `input_ids` and rely on the cache.
- - Slice attention_mask / position_ids to the last step if provided.
-
- Compatible with Transformers 4.48.x while the model's forward remains 37.x-style
- (i.e., no `cache_position` arg).
- """
-
- # ---- helper: determine if cache is effectively empty (step 0) ----
- def _cache_seq_len(cache):
- # Your custom cache (DynamicCachePlus) may not implement get_seq_length,
- # but exposes `seen_tokens`. Try both.
- try:
- return cache.get_seq_length()
- except Exception:
- try:
- return getattr(cache, "seen_tokens", 0)
- except Exception:
- return 0
-
- empty_cache = (past_key_values is None) or (_cache_seq_len(past_key_values) == 0)
-
- # HF can pass input_ids of shape (B, 0) at step 0 when generation is driven by inputs_embeds.
- input_ids_is_empty = (input_ids is None) or (hasattr(input_ids, "numel") and input_ids.numel() == 0)
-
- # We treat it as the first step if we have embeddings and either cache is empty or input_ids is empty.
- first_step = (inputs_embeds is not None) and (empty_cache or input_ids_is_empty)
-
- # ---- pick a device for any tensors we need to synthesize ----
- if inputs_embeds is not None:
- dev = inputs_embeds.device
- elif input_ids is not None:
- dev = input_ids.device
- else:
- dev = next(self.parameters()).device # fallback
-
- # ---- if we already have an attention mask but no position_ids, derive them ----
- if attention_mask is not None and position_ids is None:
- # cumulative positions where mask == 1
- position_ids = attention_mask.long().cumsum(-1) - 1
- # pad positions (mask==0) get a dummy 1 (won't be used)
- position_ids.masked_fill_(attention_mask == 0, 1)
-
- model_inputs = {}
-
- if first_step:
- # -------- Step 0: feed the full packed multimodal embeddings --------
- model_inputs["inputs_embeds"] = inputs_embeds
-
- # Synthesize attention_mask if caller didn't provide one
- if attention_mask is None:
- bsz, seqlen = inputs_embeds.shape[:2]
- attention_mask = torch.ones((bsz, seqlen), dtype=torch.long, device=dev)
-
- # If still no position_ids, derive from mask now
- if position_ids is None:
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
-
- else:
- # -------- Steps 1+: rely on KV cache; only feed the last token id --------
- if input_ids is not None and input_ids.shape[1] > 1:
- input_ids = input_ids[:, -1:]
- model_inputs["input_ids"] = input_ids
-
- # When attention_mask/position_ids exist, slice them to the latest token
- if attention_mask is not None and attention_mask.shape[1] > 1:
- attention_mask = attention_mask[:, -1:]
- if position_ids is not None and position_ids.shape[1] > 1:
- position_ids = position_ids[:, -1:]
-
- # ---- common fields passed through ----
- if attention_mask is not None:
- model_inputs["attention_mask"] = attention_mask
- if position_ids is not None:
- model_inputs["position_ids"] = position_ids
- if input_embeds_indices is not None:
- model_inputs["input_embeds_indices"] = input_embeds_indices
- if past_key_values is not None:
- model_inputs["past_key_values"] = past_key_values
-
- # IMPORTANT: do NOT pass `cache_position` -> your .forward() doesn't accept it
- # (Transformers 4.48 will still pass it to this method; we simply ignore it.)
-
- # Default to caching during generation unless the caller explicitly disables it.
- model_inputs["use_cache"] = True if use_cache is None else use_cache
-
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
- )
- return reordered_past
\ No newline at end of file
diff --git a/code/xtuner/model/fastv/__init__.py b/code/xtuner/model/fastv/__init__.py
deleted file mode 100644
index 853bf496676a8d7149fe921d0b2406cf7fca7812..0000000000000000000000000000000000000000
--- a/code/xtuner/model/fastv/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .fastv_qwen import Qwen2Model as Qwen25ModelFastV, Qwen2ForCausalLM as Qwen25ForCausalLMFastV
-
-__all__ = ["Qwen25ModelFastV", "Qwen25ForCausalLMFastV"]
diff --git a/code/xtuner/model/fastv/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/fastv/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 6ef82faf3ac804424236a9cf8d03633c4eed581f..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/fastv/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/fastv/__pycache__/fastv_qwen.cpython-311.pyc b/code/xtuner/model/fastv/__pycache__/fastv_qwen.cpython-311.pyc
deleted file mode 100644
index 6b6b83878724a695cbe838837733b3bec3cefdaa..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/fastv/__pycache__/fastv_qwen.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/fastv/cache_utils.py b/code/xtuner/model/fastv/cache_utils.py
deleted file mode 100644
index 4cc873f4e7242e0afefe719bbe147f28a4ca1833..0000000000000000000000000000000000000000
--- a/code/xtuner/model/fastv/cache_utils.py
+++ /dev/null
@@ -1,320 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-import torch
-
-
-class Cache:
- """
- Base, abstract class for all caches. The actual data structure is specific to each subclass.
- """
-
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
- cache to be created.
-
- Return:
- A tuple containing the updated key and value states.
- """
- raise NotImplementedError("Make sure to implement `update` in a subclass.")
-
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- raise NotImplementedError(
- "Make sure to implement `get_seq_length` in a subclass."
- )
-
- def get_max_length(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states, if there is any."""
- raise NotImplementedError(
- "Make sure to implement `get_max_length` in a subclass."
- )
-
- def get_usable_length(
- self, new_seq_length: int, layer_idx: Optional[int] = 0
- ) -> int:
- """Given the sequence length of the new inputs, returns the usable length of the cache."""
- # Cache without size limit -> all cache is usable
- # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
- # length, we will need to evict part of the cache (and thus not all cache is usable)
- max_length = self.get_max_length()
- previous_seq_length = self.get_seq_length(layer_idx)
- if max_length is not None and previous_seq_length + new_seq_length > max_length:
- return max_length - new_seq_length
- return previous_seq_length
-
-
-class DynamicCachePlus(Cache):
- """
- A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
- `[batch_size, num_heads, seq_len, head_dim]`.
- """
-
- def __init__(self) -> None:
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- self.seen_tokens = (
- 0 # Used in `generate` to keep tally of how many tokens the cache has seen
- )
-
- # ----------------------------------------------------------#
- self.true_cache_length: List[torch.Tensor] = [] # L * [B]
- # ----------------------------------------------------------#
-
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
- """
- Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
- sequence length.
- """
- if layer_idx < len(self):
- return (self.key_cache[layer_idx], self.value_cache[layer_idx])
- else:
- raise KeyError(
- f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
- )
-
- def __iter__(self):
- """
- Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
- keys and values
- """
- for layer_idx in range(len(self)):
- yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
-
- def __len__(self):
- """
- Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
- to the number of layers in the model.
- """
- return len(self.key_cache)
-
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- cache_decision: Optional[torch.Tensor] = None, # [B * N]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
- Return:
- A tuple containing the updated key and value states.
- """
- B, _, N, _ = key_states.shape
-
- # Update the number of seen tokens
- if layer_idx == 0:
- self.seen_tokens += key_states.shape[-2]
-
- if len(self.key_cache) <= layer_idx:
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
-
- # ----------------------------------------------------------#
- # for prefill stage
- if cache_decision is not None:
- self.true_cache_length.append(cache_decision.sum(dim=-1))
- else:
- self.true_cache_length.append(torch.tensor([N]).repeat(B))
- # ----------------------------------------------------------#
- else:
- # ----------------------------------------------------------#
- if cache_decision is not None:
- if B == 1 and N == 1:
- if cache_decision[0, 0]:
- self.key_cache[layer_idx] = torch.cat(
- [self.key_cache[layer_idx], key_states], dim=-2
- )
- self.value_cache[layer_idx] = torch.cat(
- [self.value_cache[layer_idx], value_states], dim=-2
- )
-
- self.true_cache_length[layer_idx] += N
- else:
- pass
- else: # TODO, efficiency needs to be optimized
- cur_layer_key_cache_batch_list = []
- cur_layer_value_cache_batch_list = []
- for b in range(B):
- cur_keep_indice = cache_decision[b]
- keep_key_states = key_states[
- b, :, cur_keep_indice, :
- ] # H * N * C
- keep_value_states = value_states[
- b, :, cur_keep_indice, :
- ] # H * N * C
- cur_layer_key_cache = torch.cat(
- [
- self.key_cache[layer_idx][
- b, :, : self.true_cache_length[layer_idx][b], :
- ],
- keep_key_states,
- ],
- dim=-2,
- )
- cur_layer_value_cache = torch.cat(
- [
- self.value_cache[layer_idx][
- b, :, : self.true_cache_length[layer_idx][b], :
- ],
- keep_value_states,
- ],
- dim=-2,
- )
- cur_layer_key_cache_batch_list.append(cur_layer_key_cache)
- cur_layer_value_cache_batch_list.append(cur_layer_value_cache)
-
- self.true_cache_length[layer_idx][b] += (
- cache_decision[b].sum().item()
- )
-
- max_cur_layer_kv_cache_length = max(
- cur_layer_key_cache_batch_list[b].shape[-2] for b in range(B)
- )
- for b in range(B):
- cur_len = cur_layer_key_cache_batch_list[b].shape[-2]
- cur_layer_key_cache_batch_list[b] = torch.cat(
- [
- cur_layer_key_cache_batch_list[b],
- torch.zeros(
- (
- cur_layer_key_cache_batch_list[b].shape[0],
- max_cur_layer_kv_cache_length - cur_len,
- cur_layer_key_cache_batch_list[b].shape[-1],
- ),
- dtype=cur_layer_key_cache_batch_list[b].dtype,
- device=cur_layer_key_cache_batch_list[b].device,
- ),
- ],
- dim=-2,
- )
- cur_layer_value_cache_batch_list[b] = torch.cat(
- [
- cur_layer_value_cache_batch_list[b],
- torch.zeros(
- (
- cur_layer_value_cache_batch_list[b].shape[0],
- max_cur_layer_kv_cache_length - cur_len,
- cur_layer_value_cache_batch_list[b].shape[-1],
- ),
- dtype=cur_layer_value_cache_batch_list[b].dtype,
- device=cur_layer_value_cache_batch_list[b].device,
- ),
- ],
- dim=-2,
- )
- self.key_cache[layer_idx] = torch.stack(
- cur_layer_key_cache_batch_list
- )
- self.value_cache[layer_idx] = torch.stack(
- cur_layer_value_cache_batch_list
- )
- else:
- self.key_cache[layer_idx] = torch.cat(
- [self.key_cache[layer_idx], key_states], dim=-2
- )
- self.value_cache[layer_idx] = torch.cat(
- [self.value_cache[layer_idx], value_states], dim=-2
- )
-
- self.true_cache_length[layer_idx] += N
- # ----------------------------------------------------------#
-
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
- # ----------------------------------------------------------#
- def get_cache(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- if len(self.key_cache) <= layer_idx:
- return key_states, value_states
- else:
- return torch.cat(
- [self.key_cache[layer_idx], key_states], dim=-2
- ), torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
- # ----------------------------------------------------------#
-
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- if len(self.key_cache) <= layer_idx:
- return 0
- return self.key_cache[layer_idx].shape[-2]
-
- def get_max_length(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
- return None
-
- def reorder_cache(self, beam_idx: torch.LongTensor):
- """Reorders the cache for beam search, given the selected beam indices."""
- for layer_idx in range(len(self.key_cache)):
- device = self.key_cache[layer_idx].device
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
- 0, beam_idx.to(device)
- )
- device = self.value_cache[layer_idx].device
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
- 0, beam_idx.to(device)
- )
-
- # ----------------------------------------------------------#
- def to_legacy_cache(
- self,
- ) -> Tuple[Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]], torch.Tensor]:
- """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
- legacy_cache = ()
- for layer_idx in range(len(self)):
- legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
- return (legacy_cache, self.true_cache_length)
-
- @classmethod
- def from_legacy_cache(
- cls,
- past_key_values: Optional[
- Tuple[Tuple[Tuple[torch.FloatTensor]], torch.Tensor]
- ] = None,
- ) -> "DynamicCachePlus":
- """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
- cache = cls()
- if past_key_values is not None:
- for layer_idx in range(len(past_key_values[0])):
- key_states, value_states = past_key_values[0][layer_idx]
- cache.update(key_states, value_states, layer_idx)
- cache.true_cache_length = past_key_values[1]
- return cache
-
- # ----------------------------------------------------------#
diff --git a/code/xtuner/model/fastv/fastv_qwen.py b/code/xtuner/model/fastv/fastv_qwen.py
deleted file mode 100644
index 0032c614911b9b21ce42f842611d4e2e586116eb..0000000000000000000000000000000000000000
--- a/code/xtuner/model/fastv/fastv_qwen.py
+++ /dev/null
@@ -1,1464 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The Qwen team, Alibaba Group
-# and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0
-#
-# This file provides a FastV-enabled Qwen2 implementation with:
-# - Config normalization (accepts HF PretrainedConfig, dict, or str path)
-# - Optional LLM wrapper instantiation via an `llm` dict (from_pretrained)
-# - FastV token pruning in the decoder forward pass
-# - Aliases Qwen25ModelFastV/Qwen25ForCausalLMFastV for config convenience
-
-import inspect
-import math
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from xtuner.registry import BUILDER
-
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
- _prepare_4d_causal_attention_mask,
- _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast,
-)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- is_flash_attn_greater_or_equal_2_10,
- logging,
- replace_return_docstrings,
-)
-
-try:
- from transformers.generation import GenerationMixin
-except Exception:
- from transformers.generation.utils import GenerationMixin
-
-
-from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
-from transformers import AutoConfig, PretrainedConfig, PreTrainedModel as HFPreTrainedModel
-
-if is_flash_attn_2_available():
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
-
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
-else:
- _flash_supports_window_size = False
-
-logger = logging.get_logger(__name__)
-
-_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
-_CONFIG_FOR_DOC = "Qwen2Config"
-
-QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
- "Qwen/Qwen2-7B-beta",
- # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
-]
-
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
-class Qwen2RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- Qwen2RMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
-class Qwen2RotaryEmbedding(nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
- )
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
- freqs = torch.outer(t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors."""
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
-class Qwen2MLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-class Qwen2Attention(nn.Module):
- """
- Standard attention implementation (eager).
- """
-
- def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
-
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
- self.attention_dropout = config.attention_dropout
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
- self.rotary_emb = Qwen2RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.rope_theta,
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
- )
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
- "with a layer index."
- )
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
-
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class Qwen2FlashAttention2(Qwen2Attention):
- """
- FlashAttention2-based attention.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ):
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
- )
- attention_mask = kwargs.pop("padding_mask")
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
- "with a layer index."
- )
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, "sliding_window", None) is not None
- and kv_seq_len > self.config.sliding_window
- and self.config.use_sliding_window
- )
-
- if past_key_value is not None:
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (
- getattr(self.config, "sliding_window", None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents
- ):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
- f" {past_key.shape}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-
- cache_kwargs = {"sin": sin, "cos": cos}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=dropout_rate,
- use_sliding_windows=use_sliding_windows,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
- def _flash_attention_forward(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length,
- dropout=0.0,
- softmax_scale=None,
- use_sliding_windows=False,
- ):
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- causal = self.is_causal and query_length != 1
-
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
- use_sliding_windows = False
-
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, attention_mask, query_length
- )
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- if not use_sliding_windows:
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
- else:
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=(self.config.sliding_window, self.config.sliding_window),
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
- else:
- if not use_sliding_windows:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
- else:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=(self.config.sliding_window, self.config.sliding_window),
- )
-
- return attn_output
-
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-
- if kv_seq_len != attention_mask.shape[-1]:
- attention_mask_num_tokens = attention_mask.shape[-1]
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
-
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
- if query_length == kv_seq_len:
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
- )
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- )
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-
-class Qwen2SdpaAttention(Qwen2Attention):
- """
- SDPA-based attention.
- """
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if output_attentions:
- logger.warning_once(
- "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
- "does not support `output_attentions=True`. Falling back to the manual attention implementation."
- )
- return super().forward(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {"sin": sin, "cos": cos}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
-
- if query_states.device.type == "cuda" and attention_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- dropout_p=self.attention_dropout if self.training else 0.0,
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-QWEN2_ATTENTION_CLASSES = {
- "eager": Qwen2Attention,
- "flash_attention_2": Qwen2FlashAttention2,
- "sdpa": Qwen2SdpaAttention,
-}
-
-
-class Qwen2DecoderLayer(nn.Module):
- def __init__(self, config: Qwen2Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
-
- if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
- logger.warning_once(
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
- "unexpected results may be encountered."
- )
- self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-
- self.mlp = Qwen2MLP(config)
- self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. "
- "Please make sure use `attention_mask` instead.`"
- )
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-
-QWEN2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-"""
-
-
-@add_start_docstrings(
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
- QWEN2_START_DOCSTRING,
-)
-class Qwen2PreTrainedModel(PreTrainedModel):
- config_class = Qwen2Config
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["Qwen2DecoderLayer"]
- _skip_keys_device_placement = "past_key_values"
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-QWEN2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- ...
-"""
-
-
-@add_start_docstrings(
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
- QWEN2_START_DOCSTRING,
-)
-class Qwen2Model(Qwen2PreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
-
- This version integrates FastV token pruning and supports an `llm` dict to
- instantiate a CausalLM wrapper for inference (`self.llm`).
- """
-
- def __init__(
- self,
- config: Union[PretrainedConfig, Dict, str],
- llm: Optional[Union[Dict, HFPreTrainedModel]] = None,
- freeze_llm: bool = True,
- pretrained_pth: Optional[str] = None,
- projector_depth: int = 2,
- llm_lora: Optional[Dict] = None,
- use_activation_checkpointing: bool = True,
- max_position_embeddings=None,
- hidden_size: int = 512,
- enable_long_net: bool = True,
- long_net_pth: Optional[str] = None,
- projector_pth: Optional[str] = None,
- image_feature_length: int = 196,
- ):
- # ---- Normalize config to a real HF PretrainedConfig ----
- if not isinstance(config, PretrainedConfig):
- if isinstance(config, dict):
- pretrained = config.pop('pretrained_model_name_or_path', None)
- trust_remote_code = bool(config.pop('trust_remote_code', True))
- if pretrained is not None:
- cfg = AutoConfig.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
- else:
- cfg = Qwen2Config()
- for k, v in config.items():
- setattr(cfg, k, v)
- config = cfg
- elif isinstance(config, str):
- config = AutoConfig.from_pretrained(config, trust_remote_code=True)
- else:
- raise TypeError("`config` must be a transformers.PretrainedConfig, dict, or str path.")
-
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.hidden_size = hidden_size
-
- # ---- FastV settings (read from config via getattr) ----
- self.use_fast_v = getattr(config, "use_fast_v", False)
- self.fast_v_sys_length = getattr(config, "fast_v_sys_length", 0)
- self.fast_v_image_token_length = getattr(config, "fast_v_image_token_length", 0)
- self.fast_v_attention_rank = getattr(config, "fast_v_attention_rank", 0)
- self.fast_v_agg_layer = getattr(config, "fast_v_agg_layer", 0)
- self.fast_v_inplace = getattr(config, "fast_v_inplace", True)
-
- # ---- Transformer backbone ----
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList(
- [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self._attn_implementation = config._attn_implementation
- self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- self.post_init()
-
- # Optional external modules (guarded; only load if attributes exist)
- _dtype = self.embed_tokens.weight.dtype
- if long_net_pth and hasattr(self, "LongNet_encoder"):
- from safetensors.torch import load_file
- logger.info(f"Loading LongNet from {long_net_pth}")
- self.LongNet_encoder.load_state_dict(load_file(long_net_pth, device="cpu"), strict=False)
- self.LongNet_encoder.to(_dtype)
-
- if projector_pth and hasattr(self, "projector"):
- from safetensors.torch import load_file
- logger.info(f"Loading projector from {projector_pth}")
- self.projector.load_state_dict(load_file(projector_pth, device="cpu"), strict=False)
- self.projector.to(_dtype)
-
- # ---- Build/attach LM wrapper if provided via dict or model ----
- self.llm: Optional[HFPreTrainedModel] = None
- if isinstance(llm, dict) and 'type' in llm:
- builder = llm['type'] # expected callable, e.g. Qwen2ForCausalLM.from_pretrained
- kwargs = {k: v for k, v in llm.items() if k != 'type'}
- self.llm = builder(**kwargs)
- if freeze_llm:
- for p in self.llm.parameters():
- p.requires_grad = False
- elif isinstance(llm, HFPreTrainedModel):
- self.llm = llm
- if freeze_llm:
- for p in self.llm.parameters():
- p.requires_grad = False
-
- # If an LM exists and supports FastV, push config now
- self.reset_fastv()
-
- # FastV helper to refresh from config and propagate to attached LLM
- def reset_fastv(self):
- self.use_fast_v = getattr(self.config, "use_fast_v", False)
- self.fast_v_sys_length = getattr(self.config, "fast_v_sys_length", 0)
- self.fast_v_image_token_length = getattr(self.config, "fast_v_image_token_length", 0)
- self.fast_v_attention_rank = getattr(self.config, "fast_v_attention_rank", 0)
- self.fast_v_agg_layer = getattr(self.config, "fast_v_agg_layer", 0)
- self.fast_v_inplace = getattr(self.config, "fast_v_inplace", True)
-
- if self.llm is not None and hasattr(self.llm, "set_fastv_config"):
- self.llm.set_fastv_config(
- dict(
- use_fastv=self.use_fast_v,
- fast_v_agg_layer=self.fast_v_agg_layer,
- fast_v_attention_rank=self.fast_v_attention_rank,
- fast_v_sys_length=self.fast_v_sys_length,
- fast_v_image_token_length=self.fast_v_image_token_length,
- fast_v_inplace=self.fast_v_inplace,
- )
- )
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- # Disable KV cache when FastV is active (unsupported by typical FastV pruning)
- if self.use_fast_v and use_cache:
- logger.warning_once("FastV active: disabling KV cache (use_cache=False).")
- use_cache = False
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- past_key_values_length = 0
-
- if use_cache:
- use_legacy_cache = not isinstance(past_key_values, Cache)
- if use_legacy_cache:
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
- past_key_values_length = past_key_values.get_usable_length(seq_length)
- else:
- use_legacy_cache = not isinstance(past_key_values, Cache)
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
- )
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- # Prepare the base attention mask according to implementation
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
- if is_padding_right:
- raise ValueError(
- "You are attempting to perform batched generation with padding_side='right' "
- "which may misbehave for Flash Attention version of Qwen2. Use left padding."
- )
-
- if self._attn_implementation == "flash_attention_2":
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self._attn_implementation == "sdpa" and not output_attentions:
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- )
- else:
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- sliding_window=self.config.sliding_window,
- )
-
- hidden_states = inputs_embeds
-
- # ---- FastV runtime state ----
- prev_layer_attn: Optional[torch.Tensor] = None # [B, H, q_len, k_len]
- pruned_keep_indices: Optional[torch.Tensor] = None # [new_seq_len]
- gen_attention_mask: Optional[torch.Tensor] = None # prepared 4D mask (non-inplace) after pruning
- current_seq_length = seq_length
- # -----------------------------
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
-
- for idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- # Decide whether to request attentions from this layer.
- need_attn_this_layer = output_attentions or (self.use_fast_v and idx == (self.fast_v_agg_layer - 1))
-
- # If FastV is active and we have reached/prior to agg layer, choose the mask
- if self.use_fast_v:
- # By default, use the baseline attention mask
- attn_mask_for_layer = attention_mask
-
- if idx < self.fast_v_agg_layer:
- pass
-
- elif idx == self.fast_v_agg_layer:
- # Apply pruning using the *previous* layer's attentions if available
- if prev_layer_attn is not None:
- # average over heads; take last generated token's attention (q = last index)
- # assume batch size 1 for SlideChat inference
- attn_avg = prev_layer_attn.mean(dim=1)[0] # [q_len, k_len]
- last_tok_attn = attn_avg[-1] # [k_len]
- sys_len = int(self.fast_v_sys_length)
- img_len = int(self.fast_v_image_token_length)
- if img_len > 0 and sys_len + img_len <= last_tok_attn.shape[0]:
- image_slice = last_tok_attn[sys_len : sys_len + img_len]
- keep_k = max(1, int(self.fast_v_attention_rank))
- keep_k = min(keep_k, img_len)
- top_idx = torch.topk(image_slice, k=keep_k, largest=True).indices.to(hidden_states.device)
- top_idx = top_idx + sys_len # offset into full sequence
-
- # build keep indices: [0:sys_len) + top image indices + [sys_len+img_len : end)
- tail_start = sys_len + img_len
- keep_indices = torch.cat(
- (
- torch.arange(sys_len, device=hidden_states.device),
- top_idx,
- torch.arange(tail_start, current_seq_length, device=hidden_states.device),
- ),
- dim=0,
- ).sort().values
- else:
- # Fallback: do not prune if sizes are inconsistent
- keep_indices = torch.arange(current_seq_length, device=hidden_states.device)
- else:
- keep_indices = torch.arange(current_seq_length, device=hidden_states.device)
-
- pruned_keep_indices = keep_indices
-
- if self.fast_v_inplace:
- # In-place: shrink the sequence, update positions and rebuild mask for the *pruned* length
- hidden_states = hidden_states[:, pruned_keep_indices, :]
- position_ids = pruned_keep_indices.unsqueeze(0)
- current_seq_length = pruned_keep_indices.shape[0]
-
- if self._attn_implementation == "flash_attention_2":
- attn_mask_for_layer = None if attention_mask is None else attention_mask
- elif self._attn_implementation == "sdpa" and not output_attentions:
- attn_mask_for_layer = _prepare_4d_causal_attention_mask_for_sdpa(
- None, # start from a clean mask
- (batch_size, current_seq_length),
- hidden_states,
- 0,
- )
- else:
- attn_mask_for_layer = _prepare_4d_causal_attention_mask(
- None,
- (batch_size, current_seq_length),
- hidden_states,
- 0,
- sliding_window=self.config.sliding_window,
- )
- attention_mask = attn_mask_for_layer
- else:
- # Mask-only pruning: keep the full sequence but mask out dropped image tokens
- base = torch.ones((batch_size, current_seq_length), dtype=torch.bool, device=hidden_states.device)
- sys_len = int(self.fast_v_sys_length)
- img_len = int(self.fast_v_image_token_length)
- tail_start = sys_len + img_len
- if img_len > 0 and sys_len + img_len <= current_seq_length:
- base[:, sys_len:sys_len+img_len] = False
- base[:, pruned_keep_indices[(pruned_keep_indices >= sys_len) & (pruned_keep_indices < tail_start)]] = True
- if self._attn_implementation == "sdpa" and not output_attentions:
- attn_mask_for_layer = _prepare_4d_causal_attention_mask_for_sdpa(
- base, (batch_size, current_seq_length), hidden_states, 0
- )
- elif self._attn_implementation == "flash_attention_2":
- attn_mask_for_layer = base
- else:
- attn_mask_for_layer = _prepare_4d_causal_attention_mask(
- base, (batch_size, current_seq_length), hidden_states, 0, sliding_window=self.config.sliding_window
- )
- gen_attention_mask = attn_mask_for_layer
- attention_mask = gen_attention_mask # persist for following layers
- else:
- if gen_attention_mask is not None:
- attn_mask_for_layer = gen_attention_mask
- else:
- attn_mask_for_layer = attention_mask
- else:
- attn_mask_for_layer = attention_mask
-
- # Run the layer
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- attn_mask_for_layer,
- position_ids,
- past_key_values,
- need_attn_this_layer,
- use_cache,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attn_mask_for_layer,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=need_attn_this_layer,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[2 if need_attn_this_layer else 1]
-
- # Collect attention for FastV (to be used on next layer)
- if need_attn_this_layer:
- prev_layer_attn = layer_outputs[1]
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = None
- if use_cache:
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
-@BUILDER.register_module()
-class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
-
- def __init__(self, config):
- super().__init__(config)
- self.model = Qwen2Model(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- self.post_init()
-
- # Convenience helper to set FastV config at runtime
- def set_fastv_config(self, fastv_config: dict):
- """
- fastv_config keys:
- use_fastv / use_fast_v (bool)
- fastv_k / fast_v_agg_layer (int)
- fastv_r (float in (0,1]) or fast_v_attention_rank (int)
- image_token_start_index / fast_v_sys_length (int)
- image_token_length / fast_v_image_token_length (int)
- fast_v_inplace (bool, default True)
- """
- use_fast = fastv_config.get("use_fastv", fastv_config.get("use_fast_v", False))
- self.config.use_fast_v = bool(use_fast)
-
- if "fastv_k" in fastv_config:
- self.config.fast_v_agg_layer = int(fastv_config["fastv_k"])
- elif "fast_v_agg_layer" in fastv_config:
- self.config.fast_v_agg_layer = int(fastv_config["fast_v_agg_layer"])
-
- if "fastv_r" in fastv_config and "fast_v_image_token_length" in self.config.__dict__:
- img_len = int(getattr(self.config, "fast_v_image_token_length", 0))
- self.config.fast_v_attention_rank = max(1, int(round(float(fastv_config["fastv_r"]) * max(1, img_len))))
- elif "fastv_r" in fastv_config and "image_token_length" in fastv_config:
- img_len = int(fastv_config["image_token_length"])
- self.config.fast_v_attention_rank = max(1, int(round(float(fastv_config["fastv_r"]) * max(1, img_len))))
- elif "fast_v_attention_rank" in fastv_config:
- self.config.fast_v_attention_rank = int(fastv_config["fast_v_attention_rank"])
-
- if "image_token_start_index" in fastv_config:
- self.config.fast_v_sys_length = int(fastv_config["image_token_start_index"])
- elif "fast_v_sys_length" in fastv_config:
- self.config.fast_v_sys_length = int(fastv_config["fast_v_sys_length"])
-
- if "image_token_length" in fastv_config:
- self.config.fast_v_image_token_length = int(fastv_config["image_token_length"])
- elif "fast_v_image_token_length" in fastv_config:
- self.config.fast_v_image_token_length = int(fastv_config["fast_v_image_token_length"])
-
- if "fast_v_inplace" in fastv_config:
- self.config.fast_v_inplace = bool(fastv_config["fast_v_inplace"])
-
- # Push down to the underlying model
- self.model.reset_fastv()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
-
- >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- logits = self.lm_head(hidden_states)
- logits = logits.float()
-
- loss = None
- if labels is not None:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
- ):
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- cache_length = past_key_values.get_seq_length()
- past_length = past_key_values.seen_tokens
- max_cache_length = past_key_values.get_max_length()
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
- max_cache_length = None
-
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
- elif past_length < input_ids.shape[1]:
- input_ids = input_ids[:, past_length:]
-
- if (
- max_cache_length is not None
- and attention_mask is not None
- and cache_length + input_ids.shape[1] > max_cache_length
- ):
- attention_mask = attention_mask[:, -max_cache_length:]
-
- position_ids = kwargs.get("position_ids", None)
- if attention_mask is not None and position_ids is None:
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values is not None:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids}
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "attention_mask": attention_mask,
- }
- )
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
- )
- return reordered_past
-
-
-@add_start_docstrings(
- """
- The Qwen2 Model transformer with a sequence classification head on top (linear layer).
- """,
- QWEN2_START_DOCSTRING,
-)
-class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = Qwen2Model(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
- sequence_lengths = sequence_lengths.to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
-
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
-
-
-# -------------- aliases for convenience in configs --------------
-# Qwen25ModelFastV = Qwen2Model
-# Qwen25ForCausalLMFastV = Qwen2ForCausalLM
diff --git a/code/xtuner/model/internvl.py b/code/xtuner/model/internvl.py
deleted file mode 100644
index 0358266a9ff40defc650ca62179a1c496653bed7..0000000000000000000000000000000000000000
--- a/code/xtuner/model/internvl.py
+++ /dev/null
@@ -1,320 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from collections import OrderedDict
-from typing import List, Optional, Tuple, Union
-
-import torch
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from torch.nn import CrossEntropyLoss
-from transformers import (AutoConfig, AutoModel, AutoTokenizer,
- BitsAndBytesConfig)
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from xtuner.registry import BUILDER
-from .utils import (find_all_linear_names, get_peft_model_state_dict,
- guess_load_checkpoint, make_inputs_require_grad)
-
-
-class InternVL_V1_5(BaseModel):
-
- def __init__(self,
- model_path,
- freeze_llm=False,
- freeze_visual_encoder=False,
- llm_lora=None,
- visual_encoder_lora=None,
- quantization_vit=False,
- quantization_llm=False,
- pretrained_pth=None):
- print_log('Start to load InternVL_V1_5 model.', logger='current')
- super().__init__()
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = freeze_visual_encoder
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = visual_encoder_lora is not None
- self.quantization_vit = quantization_vit
- self.quantization_llm = quantization_llm
- if quantization_vit:
- assert visual_encoder_lora is not None
- if quantization_llm:
- assert quantization_llm and llm_lora is not None
-
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
- if config.llm_config.model_type == 'internlm2':
- config.llm_config.attn_implementation = 'flash_attention_2'
- else:
- config.llm_config._attn_implementation = 'flash_attention_2'
-
- if quantization_vit is False and quantization_llm is False:
- quantization = None
- else:
- llm_int8_skip_modules = ['mlp1']
- if quantization_llm and not quantization_vit:
- llm_int8_skip_modules.append('vision_model')
-
- if quantization_vit and not quantization_llm:
- llm_int8_skip_modules.append('language_model')
-
- quantization_config = dict(
- type=BitsAndBytesConfig,
- llm_int8_skip_modules=llm_int8_skip_modules,
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4')
- quantization_clazz = quantization_config.pop('type')
- quantization = quantization_clazz(**quantization_config)
-
- self.model = AutoModel.from_pretrained(
- model_path,
- torch_dtype=torch.bfloat16,
- quantization_config=quantization,
- config=config,
- trust_remote_code=True)
-
- tokenizer = AutoTokenizer.from_pretrained(
- model_path, trust_remote_code=True)
- img_context_token_id = tokenizer.convert_tokens_to_ids('')
- self.model.img_context_token_id = img_context_token_id
-
- if self.freeze_llm:
- self.model.language_model.requires_grad_(False)
- if self.freeze_visual_encoder:
- self.model.vision_model.requires_grad_(False)
-
- if hasattr(self.model.language_model, 'enable_input_require_grads'):
- self.model.language_model.enable_input_require_grads()
- else:
- self.model.language_model.get_input_embeddings(
- ).register_forward_hook(make_inputs_require_grad)
-
- self.gradient_checkpointing_enable()
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora)
-
- if self.use_visual_encoder_lora:
- self._prepare_visual_encoder_for_lora(visual_encoder_lora)
-
- if pretrained_pth is not None:
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
-
- self.load_state_dict(pretrained_state_dict, strict=False)
- print(f'Load pretrained weight from {pretrained_pth}')
-
- self._count = 0
- print_log(self, logger='current')
- print_log('InternVL_V1_5 construction is complete', logger='current')
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.model.language_model = prepare_model_for_kbit_training(
- self.model.language_model, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.model.language_model)
- lora_config.target_modules = modules
- self.model.language_model = get_peft_model(self.model.language_model,
- lora_config)
-
- def _prepare_visual_encoder_for_lora(self, lora_config):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.model.vision_model)
- lora_config.target_modules = modules
- self.model.vision_model = get_peft_model(self.model.vision_model,
- lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.model.language_model.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.model.language_model.gradient_checkpointing_disable()
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.model.vision_model, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'model.vision_model.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.model.language_model, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'model.language_model.' in k
- })
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'model.mlp1.' in k})
- return to_return
-
- def init_weights(self):
- pass
-
- def forward(self, data, data_samples=None, mode='loss'):
- pixel_values = data['pixel_values']
-
- if type(pixel_values) is list or pixel_values.ndim == 5:
- if type(pixel_values) is list:
- pixel_values = [
- x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
- ]
- # b*n, c, h, w
- concat_images = torch.cat([
- image.to(self.model.vision_model.dtype)
- for image in pixel_values
- ],
- dim=0)
- else:
- raise NotImplementedError()
-
- input_ids = data['input_ids']
- position_ids = data['position_ids']
- attention_mask = data['attention_mask']
- # sum is 0 are text
- image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
- image_flags = image_flags.long()
-
- labels = data['labels']
- use_cache = False
-
- # Directly calling this code in LORA fine-tuning
- # will result in an error,so we must rewrite it.
- # TODO: Once the official is fixed, we can remove it.
- # outputs = self.model(input_ids=input_ids,
- # position_ids=position_ids,
- # attention_mask=attention_mask,
- # image_flags=image_flags,
- # pixel_values=concat_images,
- # labels=labels,
- # use_cache=use_cache)
- outputs = self._llm_forward(
- input_ids=input_ids,
- position_ids=position_ids,
- attention_mask=attention_mask,
- image_flags=image_flags,
- pixel_values=concat_images,
- labels=labels,
- use_cache=use_cache)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def _llm_forward(
- self,
- pixel_values: torch.FloatTensor,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- image_flags: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- return_dict = return_dict if return_dict is not None \
- else self.model.config.use_return_dict
-
- image_flags = image_flags.squeeze(-1)
- # We only added the clone code here to avoid the error.
- input_embeds = self.model.language_model.get_input_embeddings()(
- input_ids).clone()
-
- vit_embeds = self.model.extract_feature(pixel_values)
- vit_embeds = vit_embeds[image_flags == 1]
- vit_batch_size = pixel_values.shape[0]
-
- B, N, C = input_embeds.shape
- input_embeds = input_embeds.reshape(B * N, C)
-
- if torch.distributed.get_rank() == 0 and self._count % 100 == 0:
- print(f'dynamic ViT batch size: {vit_batch_size}, '
- f'images per sample: {vit_batch_size / B}, '
- f'dynamic token length: {N}')
- self._count += 1
-
- input_ids = input_ids.reshape(B * N)
- selected = (input_ids == self.model.img_context_token_id)
- try:
- input_embeds[
- selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(
- -1, C)
- except Exception as e:
- vit_embeds = vit_embeds.reshape(-1, C)
- print(f'warning: {e}, input_embeds[selected].shape='
- f'{input_embeds[selected].shape}, '
- f'vit_embeds.shape={vit_embeds.shape}')
- n_token = selected.sum()
- input_embeds[
- selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
-
- input_embeds = input_embeds.reshape(B, N, C)
-
- outputs = self.model.language_model(
- inputs_embeds=input_embeds,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- logits = outputs.logits
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(
- -1, self.model.language_model.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits, ) + outputs[1:]
- return (loss, ) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
diff --git a/code/xtuner/model/llava.py b/code/xtuner/model/llava.py
deleted file mode 100644
index 144acae17f9ce9ab3a8e028031ba12109bf8ae5f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava.py
+++ /dev/null
@@ -1,1009 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.distributed as dist # === MOD ===
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-import os
-from safetensors.torch import load_file, save_file
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNetWithMerging import make_swin_longnet_from_name
-# from xtuner.model.torchscale.model.LongNet import make_longnet_from_name
-from .torchscale.model.LongNet import make_longnet_from_name
-import torch.nn.functional as F
-
-# ===== 在类前或类内其它位置都可以:新增一个探测函数 =====
-def _detect_qwen_major_version(llm) -> int:
- """
- 返回 3 表示 Qwen3,2 表示 Qwen2,0 表示未知/其它。
- 优先用 config.model_type,其次回退到类名字符串。
- """
- base = llm.model if hasattr(llm, "model") else llm
- cfg = getattr(base, "config", None)
- mt = (getattr(cfg, "model_type", None) or "").lower()
- if mt == "qwen3":
- return 3
- if mt == "qwen2":
- return 2
-
- # 回退:根据类名判别
- cname = base.__class__.__name__.lower()
- if "qwen3" in cname:
- return 3
- if "qwen2" in cname:
- return 2
- return 0
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True,
- long_net_pth=None,
- projector_pth = None,
- perceiver_pth = None,
- #config swin_longnet
- use_swin_longnet = True,
- add_abs_pe_to_longnet_inputs = True,
-
- longnet_pe_gate_ratio = 0.1,
- longnet_pe_dropout_rate = 0.1,
-
- fourier_dims = 32,
-
- # config for Perceiver Resampler
- use_perceiver_resampler = True,
- perceiver_num_latents=64,
- perceiver_depth=2,
- perceiver_fourier_dims = 32,
- perceiver_pe_gate_ratio = 0.1,
- perceiver_pe_dropout_rate = 0.1
- ):
- super().__init__()
-
- self.enable_long_net = enable_long_net
-
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
-
- self.use_swin_longnet = use_swin_longnet
-
- if train_stage == '0':
- print_log('train_stage == 0', 'current')
- self.freeze_llm = True
- self.freeze_long_net = True
-
- if train_stage == '1':
- print_log('train_stage == 1', 'current')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print_log('train_stage == 2', 'current')
- self.freeze_llm = False #False
- self.freeze_long_net = False #False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
-
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
-
- if not self.use_swin_longnet:
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name,
- enable_gradient_checkpoint= False) # , drop_path_rate=0.3, dropout=0.3, segment_length=1024
- else:
- print('use swin long net')
- from .coords_pe import Coord2Embed
- self.add_abs_pe_to_longnet_inputs = add_abs_pe_to_longnet_inputs
- self.coord2embed_longnet = Coord2Embed(out_dim=hidden_size,
- fourier_dims=fourier_dims).to(dtype=self.llm.dtype) # hidden_size == LongNet input C
- self.longnet_pe_gate = nn.Parameter(torch.tensor(longnet_pe_gate_ratio,
- dtype=self.llm.dtype))
-
- self.longnet_pe_dropout = nn.Dropout(p = longnet_pe_dropout_rate) # optional
-
- self.LongNet_encoder = make_swin_longnet_from_name(self.encoder_name,
- keep_dim_after_merge= True,
- merge_size = 2,
- use_rel_pos_2d= False,
- enable_gradient_checkpoint= False
- )
- self.LongNet_encoder = self.LongNet_encoder.to(self.llm.dtype)
-
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- self.use_perceiver_resampler = use_perceiver_resampler
- if self.use_perceiver_resampler:
- # ---- 自动选择 Qwen3 或 Qwen2 的 Perceiver 实现 ----
- self.perceiver_num_latents = perceiver_num_latents
- self.perceiver_depth = perceiver_depth
-
- from .coords_pe import Coord2Embed
- self.key_pos_enc = Coord2Embed(out_dim=self.hidden_size,
- fourier_dims=perceiver_fourier_dims).to(dtype=self.llm.dtype)
- self.key_pos_gate = nn.Parameter(torch.tensor(perceiver_pe_gate_ratio
- , dtype =self.llm.dtype))
- self.key_pos_dropout = nn.Dropout(p = perceiver_pe_dropout_rate)
-
- qwen_major = _detect_qwen_major_version(self.llm)
- print_log(f'using qwen version{qwen_major}', 'current')
- if qwen_major == 3:
- # Qwen3 分支
- try:
- from .qwen3_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
- print_log('using qwen3', 'current')
- except Exception as e:
- raise RuntimeError(
- "检测到 Qwen3,但未找到 qwen3_perceiver_resampler,请确认文件存在且 transformers 版本满足要求(>=4.51)。"
- ) from e
- elif qwen_major == 2:
- # Qwen2 分支
- from .qwen2_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
- else:
- warnings.warn(
- "未能确定 Qwen 主版本(既不是 qwen3 也不是 qwen2)。将回退到 Qwen2 的 Perceiver 实现。",
- RuntimeWarning,
- )
- from .qwen2_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
-
- # 构建并初始化 Perceiver
- self.perceiver = _PR(
- self.llm,
- num_latents=self.perceiver_num_latents,
- depth=self.perceiver_depth,
- ).to(self.llm.dtype)
-
- _init_pr(
- perceiver=self.perceiver,
- llm=self.llm,
- ckpt_hint=getattr(self.llm.config, "_name_or_path", None),
- init_from_layers=self.perceiver.depth,
- layer_offset=0,
- allow_download=False,
- )
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
- if self.use_perceiver_resampler:
- self.perceiver.enable_input_require_grads()
-
- self.projector.enable_input_require_grads()
-
- # self.LongNet_encoder.enable_input_require_grads()
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = None
- if self.use_llm_lora:
- print_log(f"Building lora {llm_lora.__str__}", "current")
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- # ── 2) Load projector + LongNet from safetensors ────────────────────────
- if long_net_pth is not None:
- print_log(f"Loading LongNet from {long_net_pth}", "current")
- ln_sd = load_file(long_net_pth, device="cpu")
- self.LongNet_encoder.load_state_dict(ln_sd, strict=False)
- self.LongNet_encoder.to(self.llm.dtype)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- if perceiver_pth is not None and self.use_perceiver_resampler:
- print_log(f'Loading perceiver from {perceiver_pth}", "current ')
- perceiver_sd = load_file(perceiver_pth, device="cpu")
- self.projector.load_state_dict(perceiver_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- # ── 3) Optionally load a classic float checkpoint and filter mismatches ──
- if pretrained_pth is not None:
- sd = guess_load_checkpoint(pretrained_pth)
- model_sd = self.state_dict()
- filtered = {
- k: v for k, v in sd.items()
- if k in model_sd and model_sd[k].shape == v.shape
- }
- missing, unexpected = self.load_state_dict(filtered, strict=False)
- print_log(f"Loaded float ckpt from {pretrained_pth}", "current")
- print_log(f" missing: {missing}", "current")
- print_log(f" unexpected:{unexpected}", "current")
-
- self.visual_select_layer = visual_select_layer
-
- self._is_init = True
-
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- if self.use_perceiver_resampler:
- self.perceiver.enable_input_require_grads()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- if self.use_perceiver_resampler:
- self.perceiver.disable_gradient_checkpointing()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
-
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
-
- # Step 5. Perceiver Resampler (unchanged)
- if getattr(self, 'use_perceiver_resampler', False) and getattr(self, 'perceiver', None) is not None:
- to_return.update({k: v for k, v in state_dict.items() if 'perceiver.' in k})
-
- # Step 6. NEW — Positional encoders & gates
- # 6a) LongNet input PE
- if hasattr(self, 'coord2embed_longnet'):
- to_return.update({k: v for k, v in state_dict.items() if 'coord2embed_longnet.' in k})
-
- if 'longnet_pe_gate' in state_dict:
- to_return['longnet_pe_gate'] = state_dict['longnet_pe_gate']
-
- # 6b) Perceiver key-side PE (only if you added it)
- if hasattr(self, 'key_pos_enc'):
- to_return.update({k: v for k, v in state_dict.items() if 'key_pos_enc.' in k})
- if 'key_pos_gate' in state_dict:
- to_return['key_pos_gate'] = state_dict['key_pos_gate']
-
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
-
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- # data_dict['pixel_values']=[[pixel_values of img1], [pixel_values of img2], ...]
- coords_v = None
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
- if self.enable_long_net:
- if not self.use_swin_longnet:
- # long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"] # shape: (img_num, 1, 1024)
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj)["encoder_out"]
- elif self.add_abs_pe_to_longnet_inputs and 'coords' in data:
- # coords_shape = data['coords'].shape
- # print_log(f'using swin long net with coords shape {coords_shape}', 'current')
- pe = self.coord2embed_longnet(data['coords'].to(feat_to_proj.dtype)).to(feat_to_proj.dtype)
- feat_to_proj = feat_to_proj + self.longnet_pe_dropout(self.longnet_pe_gate * pe)
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj,
- coords=data['coords'].to(self.llm.dtype))
- long_net_output, coords_v = long_net_output["encoder_out"], long_net_output['coords']
- elif 'coords' in data:
- # feat_to_proj = feat_to_proj + self.longnet_pe_dropout(self.longnet_pe_gate * pe)
- long_net_output = self.LongNet_encoder(src_tokens=None,
- token_embeddings=feat_to_proj,
- coords=data['coords'].to(self.llm.dtype))
- long_net_output, coords_v = long_net_output["encoder_out"], long_net_output['coords']
- else:
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj)["encoder_out"]
- # wl - output shape (img_num, 1, 512)
- feat_to_proj = long_net_output # permuted shape: [1, img_num, 512]
-
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, 3584]
- if self.use_perceiver_resampler and 'input_ids' in data:
-
- # do this only here to void copy embedding layer to preceiver
- text_embeddings = self.llm.get_input_embeddings()(
- data["input_ids"].clamp(min=0)
- ).to(self.llm.dtype).detach()
- if coords_v is not None:
- kpe = self.key_pos_enc(coords_v.to(pixel_values.device)).to(pixel_values.dtype)
- pixel_values = pixel_values + self.key_pos_dropout(self.key_pos_gate * kpe)
- compressed = self.perceiver(
- # input_ids = data["input_ids"],
- text_embeddings=text_embeddings,
- attention_mask=data.get("attention_mask", None),
- visual_tokens=pixel_values,
- )
- data["pixel_values"] = compressed
- else:
- data['pixel_values'] = pixel_values # shape: [1, patch_num, 3584] # shape: [1, 576, 4096]
-
- # remove coords
- data.pop('coords', None)
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- # def compute_loss(self, data, data_samples=None):
- # outputs = self.llm(**data)
- # # outputs.logits.shape (1, 1094, 152064) for Qwen
- # loss_dict = {'loss': outputs.loss}
- # return loss_dict
- # === MOD: token-averaged, globally weighted loss (robust to variable lengths)
-
-
- def compute_loss(self, data, data_samples=None):
- # 1) 若无 labels,退回 HF 默认
- if 'labels' not in data:
- outputs = self.llm(**data)
- return {'loss': outputs.loss}
-
- labels = data['labels'] # [B, T]
- # 不把 labels 交给 HF,避免其先做 per-device mean
- model_inputs = {k: v for k, v in data.items() if k != 'labels'}
-
- outputs = self.llm(**model_inputs, use_cache=False)
- logits = outputs.logits # [B, T, V]
-
- # 2) CausalLM 对齐
- shift_logits = logits[:, :-1, :].contiguous()
- shift_labels = labels[:, 1:].contiguous()
-
- # 3) 本卡有效 token 数(忽略 -100)
- n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
- # 4) 分子:sum over tokens(用 FP32 计算更稳)
- loss_sum_local = F.cross_entropy(
- shift_logits.float().view(-1, shift_logits.size(-1)),
- shift_labels.view(-1),
- ignore_index=-100,
- reduction='sum'
- )
-
- # 5) 计算全局分母;不要让反传穿过 collective(用 no_grad + clone)
- world_size = 1
- n_tok_global = n_tok_local
- if dist.is_available() and dist.is_initialized():
- world_size = dist.get_world_size()
- with torch.no_grad():
- n_tok_global = n_tok_local.clone()
- dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
-
- # 6) 构造最终 loss:
- # 用“本卡分子 / 全局分母”,再乘 world_size 抵消 DDP 的梯度平均,
- # 这样反向后的等效梯度就是“全局 token 平均”的梯度。
- loss = (loss_sum_local / denom) * float(world_size)
-
- # 7) 记录指标:把 ntok 作为张量返回,避免 parse_losses 报错
- ntok_tensor = denom.detach() # float 标量张量即可
-
- return {
- 'loss': loss,
- 'ntok': ntok_tensor
- }
-
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- # self.projector.save_pretrained(projector_path,
- # **save_pretrained_kwargs)
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
- if self.use_perceiver_resampler:
-
- perceiver_path = osp.join(save_dir, "perceiver")
- print_log(f'Saving LongNet_encoder to {perceiver_path}', 'current')
- os.makedirs(perceiver_path, exist_ok=True)
- perceiver_output_path = os.path.join(perceiver_path, 'perceiver.safetensors')
- save_file(self.perceiver.state_dict(), perceiver_output_path)
-
- # LongNet_encoder
- if self.LongNet_encoder is not None:
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- # Ensure the target directory exists
- os.makedirs(LongNet_encoder_path, exist_ok=True)
-
- # Define the full path for the weights file
- output_path = osp.join(LongNet_encoder_path, 'longnet_encoder.safetensors')
-
- # Save the state dictionary using safetensors
- save_file(self.LongNet_encoder.state_dict(), output_path)
-
-
-
-
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/llava_acmil.py b/code/xtuner/model/llava_acmil.py
deleted file mode 100644
index 3c6e1e7c6dcb96e6be0a3fef211aa2021da6ad93..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_acmil.py
+++ /dev/null
@@ -1,755 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNet import make_longnet_from_name
-import torch.nn.functional as F
-from architecture.transformer import ACMIL_GA_NoClassifier, ACMIL_MHA_NoClassifier
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-class Struct:
- def __init__(self, **entries):
- self.__dict__.update(entries)
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
-
- enable_long_net=True,
- acmil_type = 'ga',
- acmil_tokens = 2000):
- super().__init__()
-
- self.enable_long_net = enable_long_net
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = True #False
- self.freeze_long_net = True #False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
-
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name) # , drop_path_rate=0.3, dropout=0.3, segment_length=1024
-
- self.llm.config.use_cache = False
-
- dispatch_modules(self.llm)
-
- # self.LongNet_encoder = self.LongNet_encoder.to(self.llm.dtype)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- acmil_config = Struct(
- D_feat = hidden_size,
- D_inner = hidden_size,
- n_token = 1,
- mask_drop=0.6,
- n_masked_patch=acmil_tokens,
- )
-
-
- if acmil_type == 'ga':
- print('use ACMIL_GA_NoClassifier')
-
- self.acmil = ACMIL_GA_NoClassifier(
- conf = acmil_config,
- n_masked_patch= acmil_config.n_masked_patch,
- n_token= acmil_config.n_token,
- mask_drop= acmil_config.mask_drop,
- )
- elif acmil_type == 'mha':
- print('use ACMIL_MHA_NoClassifier')
-
- self.acmil = ACMIL_MHA_NoClassifier(
- conf = acmil_config,
- n_masked_patch= acmil_config.n_masked_patch,
- n_token= acmil_config.n_token,
- mask_drop= acmil_config.mask_drop,
-
- )
- else:
- raise NotImplementedError(f'Unknown acmil_type: {acmil_type}')
-
-
- self.acmil.to(self.llm.dtype)
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
- if self.freeze_long_net:
- print('remove LongNet_encoder')
- self.LongNet_encoder.requires_grad_(False)
-
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
-
- self.projector.enable_input_require_grads()
- # self.LongNet_encoder.enable_input_require_grads()
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- if pretrained_pth is not None: # load the pretrained checkpoint
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
-
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}',
- 'current')
-
- self.visual_select_layer = visual_select_layer
-
- self._is_init = True
-
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
-
- to_return.update({k: v for k, v in state_dict.items() if 'acmil.' in k})
-
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- # data_dict['pixel_values']=[[pixel_values of img1], [pixel_values of img2], ...]
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
- _, feat_to_proj, _ = self.acmil(feat_to_proj) # feat_to_proj shape: [1, 1, 512]
- # if self.enable_long_net:
- # long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"] # shape: (img_num, 1, 1024)
- # # wl - output shape (img_num, 1, 512)
- # feat_to_proj = long_net_output.permute(1, 0, 2) # permuted shape: [1, img_num, 512]
-
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, 3584]
-
- data['pixel_values'] = pixel_values.unsqueeze(0) # shape: [1, patch_num, 3584] # shape: [1, 576, 4096]
- # print_log(f'pixel_values shape: {data["pixel_values"].shape}', 'current')
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- # outputs.logits.shape (1, 1094, 152064) for Qwen
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- self.projector.save_pretrained(projector_path,
- **save_pretrained_kwargs)
-
- # LongNet_encoder
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- self.LongNet_encoder.save_pretrained(LongNet_encoder_path,
- **save_pretrained_kwargs)
-
-
-
-
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- # LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- # LongNet_encoder_state_dict = convert_state_dict_to_hf(
- # LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- # **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- # LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- # LongNet_encoder_state_dict = convert_state_dict_to_hf(
- # LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- # **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/llava_attn.py b/code/xtuner/model/llava_attn.py
deleted file mode 100644
index cc4ec9be3fd37cedb0a999086b6f2b50d3b7def4..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_attn.py
+++ /dev/null
@@ -1,744 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNet import make_longnet_from_name
-import torch.nn.functional as F
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-class LLaVAModel_Attn(BaseModel):
-
- def __init__(self,
- llm,
- visual_encoder,
- freeze_llm=True,
- freeze_visual_encoder=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True):
- super().__init__()
- self.enable_long_net = enable_long_net
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = freeze_visual_encoder
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = False
- self.freeze_long_net = False
-
- with LoadWoInit():
- # if isinstance(llm, dict):
- # llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
- self.visual_encoder = self._build_from_cfg_or_module(
- visual_encoder)
-
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name) # , drop_path_rate=0.3, dropout=0.3, segment_length=1024
-
- self.llm.config.use_cache = False
- # dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.visual_encoder.dtype)
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
- if self.freeze_visual_encoder:
- print('freeze_visual_encoder')
- self.visual_encoder.requires_grad_(False)
-
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
- # if hasattr(self.visual_encoder, 'enable_input_require_grads'):
- # self.visual_encoder.enable_input_require_grads()
- # else:
- # self.visual_encoder.get_input_embeddings(
- # ).register_forward_hook(make_inputs_require_grad)
-
- self.projector.enable_input_require_grads()
- # self.LongNet_encoder.enable_input_require_grads()
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
- # if self.use_visual_encoder_lora:
- # self._prepare_visual_encoder_for_lora(
- # visual_encoder_lora, use_activation_checkpointing)
-
- if pretrained_pth is not None:
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
-
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}',
- 'current')
-
- self.visual_select_layer = visual_select_layer
-
- self._is_init = True
-
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- # data_dict['pixel_values']=[[pixel_values of img1], [pixel_values of img2], ...]
- if 'pixel_values' in data:
-
- # multi-image
- # # data['pixel_values'].shape: (img_num, 3, 336, 336)
- # visual_outputs = self.visual_encoder(
- # data['pixel_values'].to(self.visual_encoder.dtype),
- # output_hidden_states=True)
-
- # # visual_outputs.hidden_states[self.visual_select_layer][:, 1:].shape: (img_num, 576, 1024)
-
- # long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=visual_outputs.hidden_states[self.visual_select_layer][:, 1:].permute(1, 0, 2))["encoder_out"] # shape: (576, img_num, 1024)
- # feat_to_proj = long_net_output.permute(1, 0, 2) # shape: (img_num, 576, 1024)
- # feat_to_proj = torch.mean(feat_to_proj, dim=0, keepdim=True) # shape: (1, 576, 1024)
-
- # wsi
- # chunks = torch.chunk(data['pixel_values'], chunks=5000, dim=1)
- # # 对每个 chunk 进行处理
- # processed_chunks = [torch.mean(chunk, dim=1, keepdim=True) for chunk in chunks]
- # # 合并处理后的 chunks
- # merged_tensor = torch.cat(processed_chunks, dim=1)
- # feat_to_proj = merged_tensor.to(self.visual_encoder.dtype)
-
- # print(data['pixel_values'].shape)
- # print('------------------------------self.visual_encoder.dtype--------------------------------------', self.visual_encoder.dtype)
- feat_to_proj = data['pixel_values'].to(self.visual_encoder.dtype) # torch.Size([1, img_num, 768])
- if self.enable_long_net:
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"] # shape: (576, img_num, 1024)
- feat_to_proj = long_net_output.permute(1, 0, 2) # shape: [1, img_num, 768]
-
- # print('to projector')
- # print('*'*30)555pp0--00[p\]
- # print('feat_to_proj')
- # print(feat_to_proj.shape)
-
- pixel_values = self.projector(feat_to_proj.to(self.visual_encoder.dtype))
-
-
-
- data['pixel_values'] = pixel_values # shape: [1, 576, 4096]
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- self.projector.save_pretrained(projector_path,
- **save_pretrained_kwargs)
-
- # LongNet_encoder
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- self.LongNet_encoder.save_pretrained(LongNet_encoder_path,
- **save_pretrained_kwargs)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/llava_compressor.py b/code/xtuner/model/llava_compressor.py
deleted file mode 100644
index bf2121b746ce8ae067beef69fedbab17ae64378b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_compressor.py
+++ /dev/null
@@ -1,995 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import copy
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import numpy as np
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor
- )
-
-from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers import PreTrainedModel, PretrainedConfig
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNet import make_longnet_from_name
-from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
-import torch.nn.functional as F
-from torch.nn.init import trunc_normal_
-
-# ====================== Copied from first file ====================== #
-def get_abs_pos(abs_pos, tgt_size):
- """
- Interpolates 1D absolute positional embeddings to a target size.
- This function is modified to handle 1D positional embeddings, which is
- suitable for sequences of tokens that do not form a square grid.
-
- Args:
- abs_pos (torch.Tensor): The absolute positional embedding tensor of shape (N, C),
- where N is the original sequence length and C is the embedding dim.
- tgt_size (int): The target sequence length.
- Returns:
- torch.Tensor: The interpolated positional embedding tensor of shape (tgt_size, C).
- """
- src_size = abs_pos.size(0)
- dtype = abs_pos.dtype
-
- if src_size == tgt_size:
- return abs_pos
-
- # For 1D interpolation, input tensor to F.interpolate should be (N, C, L)
- # We reshape our (L, C) tensor to (1, C, L)
- interp_input = abs_pos.float().permute(1, 0).unsqueeze(0)
-
- # Perform linear interpolation
- interp_output = F.interpolate(
- interp_input,
- size=tgt_size,
- mode='linear',
- align_corners=False,
- )
-
- # Reshape back to (L_new, C)
- interpolated_pos = interp_output.squeeze(0).permute(1, 0).to(dtype)
-
- return interpolated_pos
-
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h)
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
- emb = np.concatenate([emb_h, emb_w], axis=1)
- return emb
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega
-
- pos = pos.reshape(-1)
- out = np.einsum('m,d->md', pos, omega)
-
- emb_sin = np.sin(out)
- emb_cos = np.cos(out)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1)
- return emb
-
-# Step 1: Create a configuration class for the Resampler
-class ResamplerConfig(PretrainedConfig):
- """
- Configuration class for the Resampler module.
- """
- model_type = "resampler"
- _auto_class = 'AutoConfig'
- def __init__(
- self,
- grid_size,
- embed_dim,
- num_heads,
- kv_dim=None,
- norm_layer=nn.LayerNorm,
- **kwargs
- ):
- self.grid_size = grid_size
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.kv_dim = kv_dim
- # self.hidden_act = hidden_act
- self.norm_layer = norm_layer
- super().__init__(**kwargs)
-
-class Resampler(PreTrainedModel):
- _auto_class = 'AutoModel'
- config_class = ResamplerConfig
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
-
- def __init__(
- self,
- config: ResamplerConfig
- ):
- super().__init__(config)
- self.gradient_checkpointing = False
-
- self.num_queries = config.grid_size
- self.embed_dim = config.embed_dim
- self.num_heads = config.num_heads
- kv_dim = config.kv_dim
- norm_layer = config.norm_layer
-
- # REMOVED: Positional embedding initialization
- self.query = nn.Parameter(torch.zeros(self.num_queries, self.embed_dim))
- self.query.data.normal_(mean=0.0, std=0.02)
-
- if kv_dim is not None and kv_dim != self.embed_dim:
- self.kv_proj = nn.Linear(kv_dim, self.embed_dim, bias=False)
- else:
- self.kv_proj = nn.Identity()
-
- self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
- self.ln_q = norm_layer(self.embed_dim)
- self.ln_kv = norm_layer(self.embed_dim)
-
- nn.init.constant_(self.ln_q.bias, 0)
- nn.init.constant_(self.ln_q.weight, 1.0)
- nn.init.constant_(self.ln_kv.bias, 0)
- nn.init.constant_(self.ln_kv.weight, 1.0)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def init_weights(self):
- self.query.data.normal_(mean=0.0, std=0.02)
- nn.init.constant_(self.ln_q.bias, 0)
- nn.init.constant_(self.ln_q.weight, 1.0)
- nn.init.constant_(self.ln_kv.bias, 0)
- nn.init.constant_(self.ln_kv.weight, 1.0)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, Resampler):
- module.gradient_checkpointing = value
-
- def forward(self, x, attn_mask=None, text=None):
- Q = self.query
- x = self.kv_proj(x)
- x = self.ln_kv(x)
- Q = self.ln_q(Q)
-
- # REMOVED: Positional embedding interpolation and addition
- out, attn = self.attn(
- Q.unsqueeze(0).expand(x.size(0), Q.size(0), Q.size(1)),
- x,
- x,
- attn_mask=attn_mask
- )
- return out, attn
-
-# ====================== End of copied code ====================== #
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True,
- compressor_grid_size=2, # New parameter
- compressor_embed_dim=512, # New parameter
- prefusion_layer_num=4): # New parameter for prefusion layers
- super().__init__()
-
- self.enable_long_net = enable_long_net
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = True
- self.freeze_long_net = True
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- self.compressor_grid_size = compressor_grid_size
- self.compressor_embed_dim = compressor_embed_dim
- self.prefusion_layer_num = prefusion_layer_num
-
- # Build compressor
- compressor_config = ResamplerConfig(
- grid_size=compressor_grid_size,
- embed_dim=compressor_embed_dim,
- kv_dim= None,
- num_heads= 8, # Default value, can be adjusted
- )
-
-
-
- self.compressor = Resampler(
- config=compressor_config
- )
- self.compressor.init_weights()
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- # Build prefusion layers
- temps = copy.deepcopy(self.llm.model.layers[:prefusion_layer_num])
- self.prefusion_layers = nn.ModuleList(temps)
- del temps
-
- # self.prefusion_layers=nn.ModuleList([Qwen2DecoderLayer(self.llm.config,layer_idx=i) for i in range(self.prefusion_layer_num)])
- self.prefusion_layers.to(self.llm.dtype)
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
-
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
-
- # Move to correct dtype and device
- self.compressor = self.compressor.to(self.llm.dtype)
-
-
- if use_activation_checkpointing:
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
-
- self.projector.enable_input_require_grads()
- # for layer in self.prefusion_layers:
- # if hasattr(layer, 'enable_input_require_grads'):
- # layer.enable_input_require_grads()
- # else:
- # layer.get_input_embeddings().register_forward_hook(
- # make_inputs_require_grad)
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- if pretrained_pth is not None: # load the pretrained checkpoint
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}', 'current')
-
- self.visual_select_layer = visual_select_layer
- self._is_init = True
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- self.compressor.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
- # for layer in self.prefusion_layers:
- # layer.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- self.compressor.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- # for layer in self.prefusion_layers:
- # layer.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Compressor and Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'compressor.' in k})
-
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
- # Step 4. Prefusion layers
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'prefusion_layers.' in k})
-
- # Step 5. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype)
- if self.enable_long_net:
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"]
- feat_to_proj = long_net_output.permute(1, 0, 2)
-
- # Apply compressor
- compressed_features, _ = self.compressor(feat_to_proj)
- # print_log('compressed_features shape: {}'.format(compressed_features.shape), 'current')
- # Apply projector
- pixel_values = self.projector(compressed_features)
- projected_global_image_features = self.projector(feat_to_proj)
-
- # Apply prefusion layers if any
- if self.prefusion_layer_num > 0:
- # print_log('Applying prefusion layers', 'current')
- input_ids = data['input_ids']
- text_embeddings = self.llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
- padding_mask=(input_ids <= 0)
-
- # attention_mask = data['attention_mask']
- # position_ids = data.get('position_ids', None)
-
- x = torch.cat([projected_global_image_features, pixel_values, text_embeddings], dim=1)
-
- mask=torch.cat((torch.zeros((padding_mask.size(0),projected_global_image_features.size(1)+pixel_values.size(1)),device=padding_mask.device).bool(),padding_mask),dim=1)
-
- # Prepare attention mask for prefusion layers
- if getattr(self.llm, "_use_flash_attention_2", False) or \
- getattr(self.llm.config, "_attn_implementation", "") == "flash_attention_2":
- attention_mask = (~mask).int()
- else:
- attention_mask =_prepare_4d_causal_attention_mask(~mask, (x.size(0), x.size(1)), x, 0)
-
- position_ids = (~mask).int().long().cumsum(-1) - 1
- position_ids.masked_fill_((~mask).int() == 0, 1)
- # Apply prefusion layers
- for layer in self.prefusion_layers:
- x = layer(
- x,
- attention_mask=attention_mask,
- position_ids=position_ids,
- use_cache=False,
- )[0]
-
-
- # data['inputs_embeds'] = x
- fusion_text_features=x[:,-1*input_ids.size(1):,:]
- pixel_values=x[:,-1*input_ids.size(1)-1*pixel_values.size(1):-1*input_ids.size(1),:]
- fusion_text_features=fusion_text_features*(~padding_mask).unsqueeze(-1).int()+ text_embeddings*padding_mask.unsqueeze(-1)
- data['text_features'] = fusion_text_features
- # print_log('text_features shape: {}'.format(fusion_text_features.shape), 'current')
-
- data['pixel_values'] = pixel_values
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
- outputs = self.llm(**data)
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- self.projector.save_pretrained(projector_path,
- **save_pretrained_kwargs)
- # compressor
- compressor_path = osp.join(save_dir, 'compressor')
- print_log(f'Saving compressor to {compressor_path}', 'current')
- self.compressor.save_pretrained(compressor_path,
- **save_pretrained_kwargs)
-
- # Prefusion layers
- if self.prefusion_layer_num > 0:
- prefusion_path = osp.join(save_dir, 'prefusion_layers')
- print_log(f'Saving prefusion layers to {prefusion_path}', 'current')
- torch.save(self.prefusion_layers.state_dict(),
- osp.join(prefusion_path, 'prefusion_layers.bin'))
-
- # LongNet_encoder
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- self.LongNet_encoder.save_pretrained(LongNet_encoder_path,
- **save_pretrained_kwargs)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- COMPRESSOR_MAPPING = {
- 'compressor': 'compressor'
- }
- PREFUSION_MAPPING = {
- 'prefusion_layers': 'prefusion_layers'
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- compressor_state_dict = self.compressor.state_dict()
- compressor_state_dict = convert_state_dict_to_hf(
- compressor_state_dict, COMPRESSOR_MAPPING)
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- prefusion_state_dict = self.prefusion_layers.state_dict()
- prefusion_state_dict = convert_state_dict_to_hf(
- prefusion_state_dict, PREFUSION_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **compressor_state_dict,
- **prefusion_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- COMPRESSOR_MAPPING = {
- 'compressor': 'model.compressor'
- }
- PREFUSION_MAPPING = {
- 'prefusion_layers': 'model.prefusion_layers'
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- compressor_state_dict = self.compressor.state_dict()
- compressor_state_dict = convert_state_dict_to_hf(
- compressor_state_dict, COMPRESSOR_MAPPING)
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- prefusion_state_dict = self.prefusion_layers.state_dict()
- prefusion_state_dict = convert_state_dict_to_hf(
- prefusion_state_dict, PREFUSION_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **compressor_state_dict,
- **prefusion_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/llava_dim_reducer.py b/code/xtuner/model/llava_dim_reducer.py
deleted file mode 100644
index e5329cf7ee1df04b45ed9f58f5a6b00d32646942..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_dim_reducer.py
+++ /dev/null
@@ -1,1001 +0,0 @@
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor,
- PreTrainedModel, PretrainedConfig)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-from .torchscale.model.LongNet import make_longnet_from_name
-
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-
-class ReducerConfig(PretrainedConfig):
- model_type = 'Reducer'
- _auto_class = 'AutoConfig'
-
- def __init__(
- self,
- in_tokens=4096,
- out_tokens=2048,
- hidden_tokens=1024,
- num_queries = 2048,
- num_heads = 8,
- hidden_size = 3584,
- kernel_size=4,
- stride=4,
- **kwargs,
- ):
- self.in_tokens = in_tokens
- self.out_tokens = out_tokens
- self.hidden_tokens = hidden_tokens
- self.kernel_size = kernel_size,
- self.stride = stride,
- if self.hidden_tokens is None:
- self.hidden_tokens = max(self.in_tokens // 2, self.out_tokens)
-
- self.hidden_size = hidden_size
- self.num_queries = num_queries
- self.num_heads = num_heads
- super().__init__(**kwargs)
-
-class VisualTokenConvReducer(PreTrainedModel):
- """
- Wraps a Conv1d reducer with activation-checkpointing and input-grad support.
- Handles (B, T, D) inputs by transposing internally for Conv1d.
- """
- supports_gradient_checkpointing = True
-
- def __init__(self, config: ReducerConfig):
- super().__init__(config)
- self.conv = nn.Conv1d(
- in_channels=config.hidden_size,
- out_channels=config.hidden_size,
- kernel_size=config.kernel_size,
- stride=config.stride
- )
- self.gradient_checkpointing = False
-
- def enable_input_require_grads(self):
- def make_inputs_require_grad(module, inputs, output):
- output.requires_grad_(True)
- self.conv.register_forward_hook(make_inputs_require_grad)
-
- def forward(self, x):
- """
- x: [batch, tokens_in, features]
- returns: [batch, tokens_out, features]
- """
- # x shape: [B, T_in, D], Conv1d expects [B, D, T_in]
- x = x.transpose(1, 2)
-
- if self.gradient_checkpointing and self.training:
- out = cp.checkpoint(self.conv, x)
- else:
- out = self.conv(x)
-
- # Output from conv is [B, D, T_out], transpose back to [B, T_out, D]
- out = out.transpose(1, 2)
- return out
-
-
-class VisualTokenMLPReducer(PreTrainedModel):
- """
- MLP-based token reducer. Handles (B, T, D) inputs by transposing internally.
- """
- base_model_prefix = "visual_token_mlp_reducer"
- supports_gradient_checkpointing = True
-
- def __init__(self, cfg:ReducerConfig):
- super().__init__(cfg)
- self.in_tokens = cfg.in_tokens
- self.out_tokens = cfg.out_tokens
- self.hidden_tokens = cfg.hidden_tokens
-
- # Two-layer MLP on the token dimension
- self.fc1 = nn.Linear(self.in_tokens, self.hidden_tokens, bias=True)
- self.act = nn.GELU()
- self.fc2 = nn.Linear(self.hidden_tokens, self.out_tokens, bias=True)
- self.gradient_checkpointing = False
-
- def enable_input_require_grads(self):
- def make_inputs_require_grad(module, inputs, output):
- output.requires_grad_(True)
- self.fc1.register_forward_hook(make_inputs_require_grad)
-
- def forward(self, x):
- """
- x: [batch, tokens_in, features]
- returns: [batch, tokens_out, features]
- """
- B, T, D = x.shape
- # Permute to [B, D, T] to apply MLP on the token dimension
- x = x.transpose(1, 2)
-
- # Fold features into batch dim
- x = x.reshape(B * D, T)
-
- if self.gradient_checkpointing and self.training:
- out = cp.checkpoint(self._mlp, x)
- else:
- out = self._mlp(x)
-
- # Unfold back to [B, D, T']
- out = out.view(B, D, self.out_tokens)
-
- # Permute back to [B, T', D]
- out = out.transpose(1, 2)
- return out
-
- def _mlp(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.fc2(x)
- return x
-
-
-class VisualTokenAttentionReducer(PreTrainedModel):
- base_model_prefix = "visual_token_attention_reducer"
- supports_gradient_checkpointing = True
-
- def __init__(self, config: ReducerConfig):
- super().__init__(config)
- # M learnable queries:
- self.query_emb = nn.Parameter(torch.randn(config.num_queries, config.hidden_size))
- # cross-attention:
- self.cross_attn = nn.MultiheadAttention(
- embed_dim=config.hidden_size,
- num_heads=config.num_heads,
- batch_first=True # Process (B, T, D) inputs directly
- )
- self.gradient_checkpointing = False
-
- def enable_input_require_grads(self):
- def make_inputs_require_grad(module, inputs, output):
- if isinstance(output, (tuple, list)):
- output[0].requires_grad_(True)
- else:
- output.requires_grad_(True)
- self.cross_attn.register_forward_hook(make_inputs_require_grad)
-
- def forward(self, x):
- """
- x: (B, T_in, D) → we want (B, M_out, D)
- """
- B, T, D = x.shape
- # K and V are the input visual tokens, already in (B, T, D)
- tokens = x
-
- # expand queries to (B, M, D)
- Q = self.query_emb.unsqueeze(0).expand(B, -1, -1)
-
- if self.gradient_checkpointing and self.training:
- out = cp.checkpoint(self._attn, Q, tokens, tokens)
- else:
- out = self._attn(Q, tokens, tokens)
-
- # out: (B, M, D) - no final transpose needed
- return out
-
- def _attn(self, Q, K, V):
- # returns (B, M, D)
- out, _ = self.cross_attn(Q, K, V)
- return out
-
-
-class TextGuidedVisualTokenAttentionReducer(PreTrainedModel):
- """
- An enhanced attention-based token reducer that uses text tokens to guide
- the compression of visual tokens, operating on batch-first tensors.
- """
- base_model_prefix = "text_guided_visual_token_attention_reducer"
- supports_gradient_checkpointing = True
-
- def __init__(self, config: ReducerConfig):
- super().__init__(config)
- self.query_emb = nn.Parameter(
- torch.randn(config.num_queries, config.hidden_size))
-
- self.norm_kv = nn.LayerNorm(config.hidden_size)
- self.cross_attn = nn.MultiheadAttention(
- embed_dim=config.hidden_size,
- num_heads=config.num_heads,
- batch_first=True # Expects (Batch, Seq_len, Dim)
- )
- self.norm_ffn = nn.LayerNorm(config.hidden_size)
- self.ffn = nn.Sequential(
- nn.Linear(config.hidden_size, config.hidden_size * 4),
- nn.GELU(),
- nn.Linear(config.hidden_size * 4, config.hidden_size)
- )
- self.gradient_checkpointing = False
-
- def enable_input_require_grads(self):
- def make_inputs_require_grad(module, inputs, output):
- if isinstance(output, tuple):
- output[0].requires_grad_(True)
- else:
- output.requires_grad_(True)
- self.cross_attn.register_forward_hook(make_inputs_require_grad)
-
- def forward(self, visual_tokens, text_tokens, text_attention_mask=None):
- """
- Performs text-guided reduction of visual tokens.
- Args:
- visual_tokens (torch.Tensor): Visual tokens of shape (B, T_visual, D).
- text_tokens (torch.Tensor): Text token embeddings of shape (B, T_text, D).
- text_attention_mask (torch.Tensor): Mask for text tokens.
- Returns:
- torch.Tensor: Compressed visual tokens of shape (B, M_out, D).
- """
- B, T_visual, D = visual_tokens.shape
-
- # Concatenate along the sequence dimension to form Key (K) and Value (V)
- kv_tokens = torch.cat([visual_tokens, text_tokens], dim=1)
- # print_log(f'kv tokens shape: {kv_tokens.shape}', 'current')
-
- key_padding_mask = None
- if text_attention_mask is not None:
- visual_padding_mask = torch.ones(
- B, T_visual, dtype=torch.bool, device=visual_tokens.device)
- combined_mask = torch.cat([visual_padding_mask, text_attention_mask], dim=1)
- key_padding_mask = (combined_mask == 0)
-
- # Prepare queries (Q) with batch dimension first: (B, M, D)
- queries = self.query_emb.unsqueeze(0).expand(B, -1, -1)
-
- # --- Main Forward Pass ---
- attn_output, _ = self.cross_attn(
- query=queries,
- key=self.norm_kv(kv_tokens),
- value=self.norm_kv(kv_tokens),
- key_padding_mask=key_padding_mask
- )
- queries = queries + attn_output # Residual connection
-
- ffn_output = self.ffn(self.norm_ffn(queries))
- queries = queries + ffn_output # Residual connection
-
- # Final output is already (B, M, D)
- return queries
-
-
-class LLaVAModelWithReducer(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True,
- visual_token_reducer_config=None):
- super().__init__()
- self.enable_long_net = enable_long_net
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- self.freeze_llm = True
- self.freeze_long_net = False
- elif train_stage == '2':
- self.freeze_llm = True
- self.freeze_long_net = True
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.encoder_name = f"LongNet_{2}_layers_{512}_dim"
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name)
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- # Projector
- self.projector_depth = projector_depth
- proj_cfg = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
- self.projector = ProjectorModel(proj_cfg).to(self.llm.dtype)
-
- # Visual Token Reducer
- if visual_token_reducer_config:
- cfg = visual_token_reducer_config
- reducer_type = cfg.get('type', 'attention') # Default to attention
-
- if reducer_type == 'conv':
- reducer_cfg = ReducerConfig(
- hidden_size= self.llm.config.hidden_size,
- kernel_size=cfg['kernel_size'],
- stride=cfg['stride'],
- )
- self.visual_token_reducer = VisualTokenConvReducer(
- reducer_cfg
- ).to(self.llm.dtype)
-
- elif reducer_type == 'attention':
- reducer_cfg = ReducerConfig(
- in_tokens=cfg['in_tokens'],
- out_tokens=cfg['out_tokens'],
- hidden_tokens=cfg.get('hidden_tokens', None),
- num_heads=cfg.get('num_heads', 8),
- num_queries=cfg.get('num_queries', 2048),
- hidden_size=self.llm.config.hidden_size
- )
- self.visual_token_reducer = VisualTokenAttentionReducer(reducer_cfg).to(self.llm.dtype)
- elif reducer_type == 'text_guided_attention':
- reducer_cfg = ReducerConfig(
- in_tokens=cfg['in_tokens'],
- out_tokens=cfg['out_tokens'],
- hidden_tokens=cfg.get('hidden_tokens', None),
- num_heads=cfg.get('num_heads', 8),
- num_queries=cfg.get('num_queries', 2048),
- hidden_size=self.llm.config.hidden_size
- )
- self.visual_token_reducer = TextGuidedVisualTokenAttentionReducer(reducer_cfg).to(self.llm.dtype)
-
- elif reducer_type == 'mlp':
- reducer_cfg = ReducerConfig(
- in_tokens=cfg['in_tokens'],
- out_tokens=cfg['out_tokens'],
- hidden_tokens=cfg.get('hidden_tokens', None),
- hidden_size=self.llm.config.hidden_size
- )
- self.visual_token_reducer = VisualTokenMLPReducer(reducer_cfg).to(self.llm.dtype)
-
- else:
- raise ValueError(f"Unknown reducer type: {reducer_type}. "
- "Supported types: 'conv', 'attention', 'mlp', 'text_guided_attention'")
-
- # Freezing
- if self.freeze_llm:
- self.llm.requires_grad_(False)
- if getattr(self, 'freeze_long_net', False):
- self.LongNet_encoder.requires_grad_(False)
-
- # Activation / gradient checkpointing
- if use_activation_checkpointing:
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
- self.projector.enable_input_require_grads()
- if hasattr(self, 'visual_token_reducer'):
- self.visual_token_reducer.enable_input_require_grads()
- self.gradient_checkpointing_enable()
-
- # LoRA
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = visual_encoder_lora is not None
- if self.use_llm_lora:
- print_log('Using LoRA for LLM', 'current')
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- # Load pretrained
- if pretrained_pth is not None:
- state = guess_load_checkpoint(pretrained_pth)
- self.load_state_dict(state, strict=False)
- print_log(f'Loaded pretrained weights from {pretrained_pth}', 'current')
-
- self.visual_select_layer = visual_select_layer
- self._is_init = True
- self.is_first_iter = True
-
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
- if hasattr(self, 'visual_token_reducer'):
- self.visual_token_reducer.gradient_checkpointing = True
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- if hasattr(self, 'visual_token_reducer'):
- self.visual_token_reducer.gradient_checkpointing = False
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
-
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
-
- # Step 5. Visual Token Reducer
- if hasattr(self, 'visual_token_reducer'):
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'visual_token_reducer.' in k})
-
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- is_text_guided_reducer = isinstance(
- getattr(self, 'visual_token_reducer', None), TextGuidedVisualTokenAttentionReducer
- )
-
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype)
- feat_to_proj = self.LongNet_encoder(src_tokens=None,
- token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"].permute(1, 0, 2)
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype))
-
- if hasattr(self, 'visual_token_reducer'):
- if is_text_guided_reducer:
- # Get text embeddings and attention mask for the guided reducer
- input_ids = data['input_ids']
- text_attention_mask = data.get('attention_mask')
- text_embeddings = self.llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
-
- pixel_values = self.visual_token_reducer(
- pixel_values, text_embeddings, text_attention_mask
- )
-
- else:
- # Input to reducer is now (B, T, D)
- pixel_values = self.visual_token_reducer(pixel_values)
- # print_log(f'Visual tokens reduced to shape: {pixel_values.shape}', 'current')
-
- data['pixel_values'] = pixel_values
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- self.llm.half()
- if self.use_llm_lora:
- llm_adapter = osp.join(save_dir, 'llm_adapter')
- self.llm.save_pretrained(llm_adapter, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
- self.llm.save_pretrained(save_dir, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_adapter = osp.join(save_dir, 'visual_encoder_adapter')
- self.visual_encoder.save_pretrained(visual_adapter, **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- vis_dir = osp.join(save_dir, 'visual_encoder')
- BUILDER.build(cfg.image_processor).save_pretrained(vis_dir, **save_pretrained_kwargs)
- self.visual_encoder.save_pretrained(vis_dir, **save_pretrained_kwargs)
-
- # Projector
- proj_dir = osp.join(save_dir, 'projector')
- self.projector.save_pretrained(proj_dir, **save_pretrained_kwargs)
-
- # Visual Token Reducer
- if hasattr(self, 'visual_token_reducer'):
- red_dir = osp.join(save_dir, 'visual_token_reducer')
- self.visual_token_reducer.save_pretrained(red_dir, **save_pretrained_kwargs)
-
- # LongNet_encoder
- longnet_dir = osp.join(save_dir, 'LongNet_encoder')
- self.LongNet_encoder.save_pretrained(longnet_dir, **save_pretrained_kwargs)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- REDUCER_MAPPING = {
- 'query_emb': 'visual_token_reducer.query_emb',
- 'cross_attn.in_proj_weight': 'visual_token_reducer.cross_attn.in_proj_weight',
- 'cross_attn.in_proj_bias': 'visual_token_reducer.cross_attn.in_proj_bias',
- 'cross_attn.out_proj.weight':'visual_token_reducer.cross_attn.out_proj.weight',
- 'cross_attn.out_proj.bias': 'visual_token_reducer.cross_attn.out_proj.bias'
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- red_state = convert_state_dict_to_hf(
- self.visual_token_reducer.state_dict(), REDUCER_MAPPING
- ) if hasattr(self, 'visual_token_reducer') else {}
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict,
- **red_state
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
- REDUCER_MAPPING = {
- 'query_emb': 'visual_token_reducer.query_emb',
- 'cross_attn.in_proj_weight': 'visual_token_reducer.cross_attn.in_proj_weight',
- 'cross_attn.in_proj_bias': 'visual_token_reducer.cross_attn.in_proj_bias',
- 'cross_attn.out_proj.weight':'visual_token_reducer.cross_attn.out_proj.weight',
- 'cross_attn.out_proj.bias': 'visual_token_reducer.cross_attn.out_proj.bias'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- red_state = convert_state_dict_to_hf(
- self.visual_token_reducer.state_dict(), REDUCER_MAPPING
- ) if hasattr(self, 'visual_token_reducer') else {}
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict,
- **red_state # add visual token reducer state
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/llava_divprune.py b/code/xtuner/model/llava_divprune.py
deleted file mode 100644
index 86c927d7773119ee4bd8e519ec4e3384962d7ec6..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_divprune.py
+++ /dev/null
@@ -1,743 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNet import make_longnet_from_name
-import torch.nn.functional as F
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True,
- divprune_ratio=0.1):
- super().__init__()
-
- self.enable_long_net = enable_long_net
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
- if use_activation_checkpointing:
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
- self.projector.enable_input_require_grads()
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- if pretrained_pth is not None:
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}', 'current')
-
- self.visual_select_layer = visual_select_layer
- self.divprune_ratio = divprune_ratio
- self._is_init = True
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def pairwise_cosine_similarity(self, matrix):
- norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
- cosine_similarity = torch.mm(norm_matrix, norm_matrix.t())
- return cosine_similarity
-
- def pairwise_l1_distance(matrix: torch.Tensor) -> torch.Tensor:
- """
- Compute the full pairwise L1 (Manhattan) distance matrix
- for an [N, D] tensor.
- """
- # torch.cdist with p=1 computes L1 distance
- return torch.cdist(matrix, matrix, p=1)
-
-
- def DivPrune(self, visual_feature_vectors, image_feature_length, cosine_matrix=None, threshold_ratio=0.1):
- threshold_terms = int(round(threshold_ratio * image_feature_length))
- if cosine_matrix is None:
- cosine_matrix = 1.0 - (self.pairwise_l1_distance(visual_feature_vectors))
-
- s = torch.empty(threshold_terms, dtype=torch.long, device=visual_feature_vectors.device)
- for i in range(threshold_terms):
- if i == 0:
- m2 = cosine_matrix
- else:
- m2 = torch.index_select(cosine_matrix, 0, torch.index_select(s, 0, torch.arange(0, i, device=cosine_matrix.device)))
-
- if i == 0:
- scores = torch.topk(m2, 2, dim=0, largest=False).values[1, :]
- else:
- scores = torch.min(m2, dim=0).values
-
- phrase_to_add_idx = torch.argmax(scores)
- s[i] = phrase_to_add_idx
- return s, cosine_matrix
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype)
- if self.enable_long_net:
- # LongNet expects (seq_len, batch, dim)
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"]
- # Permute back to (batch, seq_len, dim)
- feat_to_proj = long_net_output.permute(1, 0, 2)
-
- # --- DivPrune Integration for Batch Processing ---
- if self.divprune_ratio < 1.0:
- # This assumes feat_to_proj has a shape of [batch_size, num_tokens, feature_dim].
- # It iterates through each item in the batch, prunes it, and stacks the results.
- # This works if num_tokens is the same for all items in the batch, which is a
- # standard assumption for batched processing.
- pruned_batch_features = []
- for visual_tokens in feat_to_proj: # Iterate over the batch dimension
- img_feature_len = visual_tokens.shape[0]
- selected_indices, _ = self.DivPrune(
- visual_tokens,
- img_feature_len,
- threshold_ratio=self.divprune_ratio
- )
- selected_indices = torch.sort(selected_indices).values
- pruned_features = visual_tokens[selected_indices]
- pruned_batch_features.append(pruned_features)
-
- # Stack the list of pruned tensors back into a single batch tensor
- feat_to_proj = torch.stack(pruned_batch_features, dim=0)
-
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype))
- data['pixel_values'] = pixel_values
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- # outputs.logits.shape (1, 1094, 152064) for Qwen
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- self.projector.save_pretrained(projector_path,
- **save_pretrained_kwargs)
-
- # LongNet_encoder
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- self.LongNet_encoder.save_pretrained(LongNet_encoder_path,
- **save_pretrained_kwargs)
-
-
-
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/llava_fewer.py b/code/xtuner/model/llava_fewer.py
deleted file mode 100644
index c0333240c3c598623f1a3861aa3249cbc205d646..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_fewer.py
+++ /dev/null
@@ -1,745 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-import os
-from safetensors.torch import load_file, save_file
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNet import make_longnet_from_name
-import torch.nn.functional as F
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True,
- long_net_pth=None,
- projector_pth = None,
- kept_index = 100 # keeped first token number
-
- ):
- super().__init__()
-
- self.enable_long_net = enable_long_net
-
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = False #False
- self.freeze_long_net = True#False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
-
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name) # , drop_path_rate=0.3, dropout=0.3, segment_length=1024
-
- self.kept_index = kept_index
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
-
- self.projector.enable_input_require_grads()
- # self.LongNet_encoder.enable_input_require_grads()
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = None
- if self.use_llm_lora:
- print_log(f"Building lora {llm_lora.__str__}", "current")
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- # ── 2) Load projector + LongNet from safetensors ────────────────────────
- if long_net_pth is not None:
- print_log(f"Loading LongNet from {long_net_pth}", "current")
- ln_sd = load_file(long_net_pth, device="cpu")
- self.LongNet_encoder.load_state_dict(ln_sd, strict=False)
- self.LongNet_encoder.to(self.llm.dtype)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- # ── 3) Optionally load a classic float checkpoint and filter mismatches ──
- if pretrained_pth is not None:
- sd = guess_load_checkpoint(pretrained_pth)
- model_sd = self.state_dict()
- filtered = {
- k: v for k, v in sd.items()
- if k in model_sd and model_sd[k].shape == v.shape
- }
- missing, unexpected = self.load_state_dict(filtered, strict=False)
- print_log(f"Loaded float ckpt from {pretrained_pth}", "current")
- print_log(f" missing: {missing}", "current")
- print_log(f" unexpected:{unexpected}", "current")
-
- self.visual_select_layer = visual_select_layer
-
- self._is_init = True
-
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- # data_dict['pixel_values']=[[pixel_values of img1], [pixel_values of img2], ...]
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
- if self.enable_long_net:
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"] # shape: (img_num, 1, 1024)
- # wl - output shape (img_num, 1, 512)
- feat_to_proj = long_net_output.permute(1, 0, 2) # permuted shape: [1, img_num, 512]
-
- feat_to_proj = feat_to_proj[:,:self.kept_index, :]
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, 3584]
-
- data['pixel_values'] = pixel_values # shape: [1, patch_num, 3584] # shape: [1, 576, 4096]
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- # outputs.logits.shape (1, 1094, 152064) for Qwen
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- # self.projector.save_pretrained(projector_path,
- # **save_pretrained_kwargs)
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
- # LongNet_encoder
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
-
- if self.LongNet_encoder is not None:
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- # Ensure the target directory exists
- os.makedirs(LongNet_encoder_path, exist_ok=True)
-
- # Define the full path for the weights file
- output_path = os.path.join(LongNet_encoder_path, 'longnet_encoder.safetensors')
-
- # Save the state dictionary using safetensors
- save_file(self.LongNet_encoder.state_dict(), output_path)
-
-
-
-
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/llava_fusion_compressor.py b/code/xtuner/model/llava_fusion_compressor.py
deleted file mode 100644
index 73db8296c686fc863f48c3f7a9323843788d0c41..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_fusion_compressor.py
+++ /dev/null
@@ -1,1015 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import copy
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import numpy as np
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor
- )
-
-from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers import PreTrainedModel, PretrainedConfig
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-from .torchscale.model.LongNet import make_longnet_from_name
-from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
-import torch.nn.functional as F
-from torch.nn.init import trunc_normal_
-
-# ====================== Copied from first file ====================== #
-def get_abs_pos(abs_pos, tgt_size):
- """
- Interpolates 1D absolute positional embeddings to a target size.
- This function is modified to handle 1D positional embeddings, which is
- suitable for sequences of tokens that do not form a square grid.
-
- Args:
- abs_pos (torch.Tensor): The absolute positional embedding tensor of shape (N, C),
- where N is the original sequence length and C is the embedding dim.
- tgt_size (int): The target sequence length.
- Returns:
- torch.Tensor: The interpolated positional embedding tensor of shape (tgt_size, C).
- """
- src_size = abs_pos.size(0)
- dtype = abs_pos.dtype
-
- if src_size == tgt_size:
- return abs_pos
-
- # For 1D interpolation, input tensor to F.interpolate should be (N, C, L)
- # We reshape our (L, C) tensor to (1, C, L)
- interp_input = abs_pos.float().permute(1, 0).unsqueeze(0)
-
- # Perform linear interpolation
- interp_output = F.interpolate(
- interp_input,
- size=tgt_size,
- mode='linear',
- align_corners=False,
- )
-
- # Reshape back to (L_new, C)
- interpolated_pos = interp_output.squeeze(0).permute(1, 0).to(dtype)
-
- return interpolated_pos
-
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h)
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
- emb = np.concatenate([emb_h, emb_w], axis=1)
- return emb
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega
-
- pos = pos.reshape(-1)
- out = np.einsum('m,d->md', pos, omega)
-
- emb_sin = np.sin(out)
- emb_cos = np.cos(out)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1)
- return emb
-
-# Step 1: Create a configuration class for the Resampler
-class ResamplerConfig(PretrainedConfig):
- """
- Configuration class for the Resampler module.
- """
- model_type = "resampler"
- _auto_class = 'AutoConfig'
- def __init__(
- self,
- grid_size,
- embed_dim,
- num_heads,
- kv_dim=None,
- norm_layer=nn.LayerNorm,
- **kwargs
- ):
- self.grid_size = grid_size
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.kv_dim = kv_dim
- # self.hidden_act = hidden_act
- self.norm_layer = norm_layer
- super().__init__(**kwargs)
-
-class Resampler(PreTrainedModel):
- _auto_class = 'AutoModel'
- config_class = ResamplerConfig
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
-
- def __init__(
- self,
- config: ResamplerConfig
- ):
- super().__init__(config)
- self.gradient_checkpointing = False
-
- self.num_queries = config.grid_size
- self.embed_dim = config.embed_dim
- self.num_heads = config.num_heads
- kv_dim = config.kv_dim
- norm_layer = config.norm_layer
-
- # REMOVED: Positional embedding initialization
- self.query = nn.Parameter(torch.zeros(self.num_queries, self.embed_dim))
- self.query.data.normal_(mean=0.0, std=0.02)
-
- if kv_dim is not None and kv_dim != self.embed_dim:
- self.kv_proj = nn.Linear(kv_dim, self.embed_dim, bias=False)
- else:
- self.kv_proj = nn.Identity()
-
- self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
- self.ln_q = norm_layer(self.embed_dim)
- self.ln_kv = norm_layer(self.embed_dim)
-
- nn.init.constant_(self.ln_q.bias, 0)
- nn.init.constant_(self.ln_q.weight, 1.0)
- nn.init.constant_(self.ln_kv.bias, 0)
- nn.init.constant_(self.ln_kv.weight, 1.0)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def init_weights(self):
- self.query.data.normal_(mean=0.0, std=0.02)
- nn.init.constant_(self.ln_q.bias, 0)
- nn.init.constant_(self.ln_q.weight, 1.0)
- nn.init.constant_(self.ln_kv.bias, 0)
- nn.init.constant_(self.ln_kv.weight, 1.0)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, Resampler):
- module.gradient_checkpointing = value
-
- def forward(self, x, attn_mask=None, text=None):
- Q = self.query
- x = self.kv_proj(x)
- x = self.ln_kv(x)
- Q = self.ln_q(Q)
-
- # REMOVED: Positional embedding interpolation and addition
- out, attn = self.attn(
- Q.unsqueeze(0).expand(x.size(0), Q.size(0), Q.size(1)),
- x,
- x,
- attn_mask=attn_mask
- )
- return out, attn
-
-# ====================== End of copied code ====================== #
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- enable_long_net=True,
- compressor_grid_size = 2048,
-
- prefusion_layer_num=2,
- image_only = False,
- ): # New parameter for prefusion layers
- super().__init__()
-
- self.enable_long_net = enable_long_net
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = True
- self.freeze_long_net = True
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(2, 512)
- self.LongNet_encoder = make_longnet_from_name(self.encoder_name)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- self.compressor_grid_size = compressor_grid_size
- self.compressor_embed_dim = self.llm.config.hidden_size
- self.prefusion_layer_num = prefusion_layer_num
-
-
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- # Build prefusion layers
- temps = copy.deepcopy(self.llm.model.layers[:prefusion_layer_num])
- self.prefusion_layers = nn.ModuleList(temps)
- del temps
-
- self.image_only = image_only
-
- self.prefusion_layers.to(self.llm.dtype)
-
- self.query_emb = nn.Parameter(torch.randn(self.compressor_grid_size, self.llm.config.hidden_size))
- nn.init.normal_(self.query_emb, mean=0.0, std=0.02)
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
-
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
-
-
- # Move to correct dtype and device
- # self.compressor = self.compressor.to(self.llm.dtype)
-
-
- if use_activation_checkpointing:
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
-
- self.projector.enable_input_require_grads()
- # for layer in self.prefusion_layers:
- # if hasattr(layer, 'enable_input_require_grads'):
- # layer.enable_input_require_grads()
- # else:
- # layer.get_input_embeddings().register_forward_hook(
- # make_inputs_require_grad)
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- if pretrained_pth is not None: # load the pretrained checkpoint
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}', 'current')
-
- self.visual_select_layer = visual_select_layer
- self._is_init = True
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.compressor.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
- # for layer in self.prefusion_layers:
- # layer.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.compressor.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- # for layer in self.prefusion_layers:
- # layer.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Compressor and Projector
- # to_return.update(
- # {k: v
- # for k, v in state_dict.items() if 'compressor.' in k})
-
- to_return.update({k: v for k, v in state_dict.items() if 'query_emb' in k})
-
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
- # Step 4. Prefusion layers
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'prefusion_layers.' in k})
-
- # Step 5. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype)
- if self.enable_long_net:
- long_net_output = self.LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj.permute(1, 0, 2))["encoder_out"]
- feat_to_proj = long_net_output.permute(1, 0, 2)
-
- # Apply projector
- # pixel_values = self.projector(feat_to_proj)
- projected_global_image_features = self.projector(feat_to_proj)
-
- # Apply prefusion layers if any
- if self.prefusion_layer_num > 0:
- # print_log('Applying prefusion layers', 'current')
- input_ids = data['input_ids']
- # attention_mask = data['attention_mask']
- # position_ids = data.get('position_ids', None)
- B, D, T_visual = projected_global_image_features.shape
- # Q = self.query_emb.unsqueeze(1).expand(-1, B, -1)
- Q = self.query_emb.unsqueeze(0).expand(B, -1, -1)
- padding_mask=(input_ids <= 0)
-
- if self.image_only:
-
- x = torch.cat([projected_global_image_features, Q], dim=1)
- mask=torch.cat((
- torch.zeros((padding_mask.size(0),projected_global_image_features.size(1)),
- device=padding_mask.device).bool(),
- torch.ones(padding_mask.size(0), Q.size(1),
- device = padding_mask.device).bool()
- ),
- dim=1)
-
- else:
-
- text_embeddings = self.llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
- B, T_img, H = projected_global_image_features.shape
- T_txt = text_embeddings.size(1)
- G = self.query_emb.size(0)
-
- Q = self.query_emb.unsqueeze(0).expand(B, -1, -1) # (B, G, H)
- x = torch.cat([projected_global_image_features, # (B, T_img, H)
- text_embeddings, # (B, T_txt, H)
- Q], dim=1) # (B, G, H)
- S = x.size(1)
-
- # --- key padding mask: only IMAGE tokens are valid keys ---
- # mask=True means "mask this position as a key" (will be inverted below)
- mask_img = torch.zeros(B, T_img, dtype=torch.bool, device=x.device) # keep as keys
- mask_txt = torch.ones( B, T_txt, dtype=torch.bool, device=x.device) # mask as keys
- mask_q = torch.ones( B, G, dtype=torch.bool, device=x.device) # mask as keys
- mask = torch.cat([mask_img, mask_txt, mask_q], dim=1) # (B, S)
-
- # --- attention mask for the prefusion layers ---
- # FA2 path: 2D key-padding mask, 1 = valid key, 0 = masked key
- if getattr(self.llm, "_use_flash_attention_2", False) or \
- getattr(self.llm.config, "_attn_implementation", "") == "flash_attention_2":
- attention_mask = (~mask).int() # (B, S)
- else:
- attention_mask = _prepare_4d_causal_attention_mask(~mask, (B, S), x, 0)
-
- # --- position ids: meaningful for image tokens, neutral for others ---
- position_ids = attention_mask.long().cumsum(-1) - 1 # (B, S)
- position_ids.masked_fill_(attention_mask == 0, 1)
-
- # run the prefusion layers
- for layer in self.prefusion_layers:
- x = layer(
- x,
- attention_mask=attention_mask,
- position_ids=position_ids,
- use_cache=False,
- )[0]
-
- # keep only the G compressed tokens (the queries)
- pixel_values = x[:, -G:, :]
- data['pixel_values'] = pixel_values
- # print_log('text_features shape: {}'.format(fusion_text_features.shape), 'current')
- else:
- data['pixel_values'] = pixel_values
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
- outputs = self.llm(**data)
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- self.projector.save_pretrained(projector_path,
- **save_pretrained_kwargs)
- # compressor
- compressor_path = osp.join(save_dir, 'compressor')
- print_log(f'Saving compressor to {compressor_path}', 'current')
- self.compressor.save_pretrained(compressor_path,
- **save_pretrained_kwargs)
-
- # Prefusion layers
- if self.prefusion_layer_num > 0:
- prefusion_path = osp.join(save_dir, 'prefusion_layers')
- print_log(f'Saving prefusion layers to {prefusion_path}', 'current')
- torch.save(self.prefusion_layers.state_dict(),
- osp.join(prefusion_path, 'prefusion_layers.bin'))
-
- # LongNet_encoder
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- self.LongNet_encoder.save_pretrained(LongNet_encoder_path,
- **save_pretrained_kwargs)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- COMPRESSOR_MAPPING = {
- 'compressor': 'compressor'
- }
- PREFUSION_MAPPING = {
- 'prefusion_layers': 'prefusion_layers'
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- # compressor_state_dict = self.compressor.state_dict()
- # compressor_state_dict = convert_state_dict_to_hf(
- # compressor_state_dict, COMPRESSOR_MAPPING)
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- prefusion_state_dict = self.prefusion_layers.state_dict()
- prefusion_state_dict = convert_state_dict_to_hf(
- prefusion_state_dict, PREFUSION_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- # **compressor_state_dict,
- **prefusion_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- COMPRESSOR_MAPPING = {
- 'compressor': 'model.compressor'
- }
- PREFUSION_MAPPING = {
- 'prefusion_layers': 'model.prefusion_layers'
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- # compressor_state_dict = self.compressor.state_dict()
- # compressor_state_dict = convert_state_dict_to_hf(
- # compressor_state_dict, COMPRESSOR_MAPPING)
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- prefusion_state_dict = self.prefusion_layers.state_dict()
- prefusion_state_dict = convert_state_dict_to_hf(
- prefusion_state_dict, PREFUSION_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- # **compressor_state_dict,
- **prefusion_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/llava_longvit.py b/code/xtuner/model/llava_longvit.py
deleted file mode 100644
index 033c5c32ae22e7516dfbb6c846c197d71baf4172..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_longvit.py
+++ /dev/null
@@ -1,1408 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.distributed as dist # === MOD ===
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor,
- AutoTokenizer)
-from transformers.integrations import is_deepspeed_zero3_enabled
-import os
-from safetensors.torch import load_file, save_file
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-import torch.nn.functional as F
-from .sparse_token_merge import SparsePatchMerging
-
-# ===== 在类前或类内其它位置都可以:新增一个探测函数 =====
-def _detect_qwen_major_version(llm) -> int:
- """
- 返回 3 表示 Qwen3,2 表示 Qwen2,0 表示未知/其它。
- 优先用 config.model_type,其次回退到类名字符串。
- """
- base = llm.model if hasattr(llm, "model") else llm
- cfg = getattr(base, "config", None)
- mt = (getattr(cfg, "model_type", None) or "").lower()
- if mt == "qwen3":
- return 3
- if mt == "qwen2":
- return 2
-
- # 回退:根据类名判别
- cname = base.__class__.__name__.lower()
- if "qwen3" in cname:
- return 3
- if "qwen2" in cname:
- return 2
- return 0
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
- LongNet_encoder = None,
- longnet_k = None,
- long_net_pth=None,
- projector_pth = None,
- perceiver_pth = None,
- #config swin_longnet
- enable_token_merge = True,
- # config for Perceiver Resampler
- use_perceiver_resampler = True,
- perceiver_num_latents=64,
- perceiver_depth=2,
-
- ):
- super().__init__()
-
- enable_long_net = False if LongNet_encoder is None else True
-
- self.enable_long_net = enable_long_net
-
- if enable_long_net:
- print('enable long net')
- else:
- print('disable long net')
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
-
- if train_stage == '0':
- print_log('train_stage == 0', 'current')
- self.freeze_llm = True
- self.freeze_long_net = True
-
- if train_stage == '1':
- print_log('train_stage == 1', 'current')
- self.freeze_llm = True
- self.freeze_long_net = False
-
- elif train_stage == '2':
- print_log('train_stage == 2', 'current')
- self.freeze_llm = False #False
- self.freeze_long_net = False #False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
-
- # try:
- llm_path = self.llm.config.name_or_path
- self.tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True)
- # Add a pad token if it doesn't exist, which is crucial for decoding labels
- if self.tokenizer.pad_token is None:
- self.tokenizer.pad_token = self.tokenizer.eos_token
- print_log("Tokenizer successfully loaded for debugging.", "current")
- # except Exception as e:
- # self.tokenizer = None
- # print_log(f"Could not load tokenizer for debugging: {e}", "current")
-
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- if self.enable_long_net:
- # self.LongNet_encoder = create_longvit_model(
- LongNet_encoder = self._build_from_cfg_or_module(LongNet_encoder)
- self.LongNet_encoder = LongNet_encoder.to(self.llm.dtype)
-
-
- self.enable_token_merge = enable_token_merge
- if self.enable_token_merge:
- self.token_merge = SparsePatchMerging(
- embed_dim= self.LongNet_encoder.embed_dim if self.enable_long_net and self.LongNet_encoder is not None else hidden_size,
- layernorm_eps=1e-6,
- merge_size= 2
- )
-
- self.projector_depth = projector_depth
- visual_dim = self.LongNet_encoder.embed_dim if self.enable_long_net and self.LongNet_encoder is not None else hidden_size
-
- projector_config = ProjectorConfig(
- visual_hidden_size= visual_dim * 4 if self.enable_token_merge else visual_dim,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- self.use_perceiver_resampler = use_perceiver_resampler
- if self.use_perceiver_resampler:
- # ---- 自动选择 Qwen3 或 Qwen2 的 Perceiver 实现 ----
- self.perceiver_num_latents = perceiver_num_latents
- self.perceiver_depth = perceiver_depth
-
- qwen_major = _detect_qwen_major_version(self.llm)
- print_log(f'using qwen version{qwen_major}', 'current')
- if qwen_major == 3:
- # Qwen3 分支
- try:
- from .qwen3_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
- print_log('using qwen3', 'current')
- except Exception as e:
- raise RuntimeError(
- "检测到 Qwen3,但未找到 qwen3_perceiver_resampler,请确认文件存在且 transformers 版本满足要求(>=4.51)。"
- ) from e
- elif qwen_major == 2:
- # Qwen2 分支
- from .qwen2_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
- else:
- warnings.warn(
- "未能确定 Qwen 主版本(既不是 qwen3 也不是 qwen2)。将回退到 Qwen2 的 Perceiver 实现。",
- RuntimeWarning,
- )
- from .qwen2_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
-
- # 构建并初始化 Perceiver
- self.perceiver = _PR(
- self.llm,
- num_latents=self.perceiver_num_latents,
- depth=self.perceiver_depth,
- ).to(self.llm.dtype)
-
- _init_pr(
- perceiver=self.perceiver,
- llm=self.llm,
- ckpt_hint=getattr(self.llm.config, "_name_or_path", None),
- init_from_layers=self.perceiver.depth,
- layer_offset=0,
- allow_download=False,
- )
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- if self.enable_long_net:
- if self.freeze_long_net:
- print('freeze_long_net')
- self.LongNet_encoder.requires_grad_(False)
- elif longnet_k is not None:
-
- self.LongNet_encoder.requires_grad_(False)
-
- # 重新打开 patch_embed(你重初始化了首层,建议可训练)
- for p in self.LongNet_encoder.patch_embed.parameters():
- p.requires_grad = True
-
- # 打开最后 K 层(假设 depth=12,打开最后 4 层)
- K = longnet_k
- if hasattr(self.LongNet_encoder, "encoder") and hasattr(self.LongNet_encoder.encoder, "layers"):
- for layer in self.LongNet_encoder.encoder.layers[-K:]:
- for p in layer.parameters():
- p.requires_grad = True
-
- # 解冻 LongNet 顶层与全局的 LayerNorm
- for name, param in self.LongNet_encoder.named_parameters():
- if name.startswith("encoder.layer_norm.") or name.startswith("norm."):
- param.requires_grad = True
- else:
- self.LongNet_encoder.requires_grad_(True)
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
- if self.use_perceiver_resampler:
- self.perceiver.enable_input_require_grads()
-
- self.projector.enable_input_require_grads()
-
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = None
- if self.use_llm_lora:
- print_log(f"Building lora {llm_lora.__str__}", "current")
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- # ── 2) Load projector + LongNet from safetensors ────────────────────────
- if long_net_pth is not None:
- print_log(f"Loading LongNet from {long_net_pth}", "current")
- ln_sd = load_file(long_net_pth, device="cpu")
- self.LongNet_encoder.load_state_dict(ln_sd, strict=False)
- self.LongNet_encoder.to(self.llm.dtype)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- if perceiver_pth is not None and self.use_perceiver_resampler:
- print_log(f'Loading perceiver from {perceiver_pth}", "current ')
- perceiver_sd = load_file(perceiver_pth, device="cpu")
- self.projector.load_state_dict(perceiver_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- # ── 3) Optionally load a classic float checkpoint and filter mismatches ──
- if pretrained_pth is not None:
- sd = guess_load_checkpoint(pretrained_pth)
- model_sd = self.state_dict()
- filtered = {
- k: v for k, v in sd.items()
- if k in model_sd and model_sd[k].shape == v.shape
- }
- missing, unexpected = self.load_state_dict(filtered, strict=False)
- print_log(f"Loaded float ckpt from {pretrained_pth}", "current")
- print_log(f" missing: {missing}", "current")
- print_log(f" unexpected:{unexpected}", "current")
-
- self.visual_select_layer = visual_select_layer
-
- self._is_init = True
-
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- if self.use_perceiver_resampler:
- self.perceiver.enable_input_require_grads()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- if self.use_perceiver_resampler:
- self.perceiver.disable_gradient_checkpointing()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
-
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. LongNet_encoder
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'LongNet_encoder.' in k})
-
- # Step 5. Perceiver Resampler (unchanged)
- if getattr(self, 'use_perceiver_resampler', False) and getattr(self, 'perceiver', None) is not None:
- to_return.update({k: v for k, v in state_dict.items() if 'perceiver.' in k})
-
- if getattr(self, 'token_merge', None) is not None:
- to_return.update({k: v for k, v in state_dict.items() if 'token_merge.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
- print_log(f"going with long sequence", 'current')
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
-
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
-
- coords = None
-
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
- if 'coords' in data:
- coords = data['coords'].to(self.llm.dtype)
- # Accept: list[tensor], [L,2] tensor, or [B,L,2] tensor
- coords_t = coords[0] if isinstance(coords, list) else coords
- Bx = feat_to_proj.size(0) # actual batch size of inputs
- if not torch.is_tensor(coords_t):
- raise ValueError("coords must be a Tensor or list[Tensor].")
-
- if coords_t.dim() == 2:
- # [L, 2]
- coords_rc = coords_t
- elif coords_t.dim() == 3:
- # [B, L, 2] -> ensure B matches and either B==1 or all examples share coords
- if coords_t.size(0) != Bx:
- raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}")
- if Bx == 1:
- coords_rc = coords_t[0]
- else:
- # require same coords across the batch (cheap equality check)
- if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)):
- raise NotImplementedError(
- "Per-example coords (varying across batch) are not supported by the current "
- "patch-merging/layout path. Use batch size 1 or share coords across the batch."
- )
- coords_rc = coords_t[0]
- else:
- raise ValueError("coords must have shape [L,2] or [B,L,2].")
-
- if coords_rc.size(-1) != 2:
- raise ValueError("coords last dimension must be 2.")
- else:
- raise RuntimeError
-
- if self.enable_long_net:
-
- long_net_output = self.LongNet_encoder(
- x = feat_to_proj,
- coords = coords,
- all_layer_embed=False
- )
-
- # output length [1, img_num +1, 768]
- long_net_output = long_net_output[-1]
- feat_to_proj = long_net_output[:, 1:, :] # remove class token
- if self.enable_token_merge and 'coords' in data:
- feat_to_proj, _ , _ = self.token_merge(
- x = feat_to_proj,
- coords_rc = self._coords_to_rowcol(coords_rc),
- padmask = torch.zeros(
- [feat_to_proj.size(0), feat_to_proj.size(1)],
- device=feat_to_proj.device,
- ).bool()
- )
-
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, 3584]
- if self.use_perceiver_resampler and 'input_ids' in data:
-
- # do this only here to void copy embedding layer to preceiver
- text_embeddings = self.llm.get_input_embeddings()(
- data["input_ids"].clamp(min=0)
- ).to(self.llm.dtype).detach()
-
- compressed = self.perceiver(
- # input_ids = data["input_ids"],
- text_embeddings=text_embeddings,
- attention_mask=data.get("attention_mask", None),
- visual_tokens=pixel_values,
- )
- data["pixel_values"] = compressed
- else:
- data['pixel_values'] = pixel_values # shape: [1, patch_num, 3584] # shape: [1, 576, 4096]
-
- # remove coords
- data.pop('coords', None)
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- @staticmethod
- def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
- with torch.no_grad():
- x = coords_xy[:, 0]
- y = coords_xy[:, 1]
- x_for_unique = x
- y_for_unique = y
- if x_for_unique.dtype.is_floating_point:
- x_for_unique = x_for_unique.round().to(torch.int)
- y_for_unique = y_for_unique.round().to(torch.int)
- x_sorted = torch.unique(x_for_unique, sorted=True)
- y_sorted = torch.unique(y_for_unique, sorted = True)
-
- col = torch.searchsorted(x_sorted, x)
- row = torch.searchsorted(y_sorted, y)
- return torch.stack([row, col], dim=-1)
-
- def _debug_first_batch(self, data):
- print("\n" + "="*60)
- print("DEBUGGING FIRST BATCH (Sample 0)")
- print("="*60)
-
- # 0) 基本取样
- input_ids = data['input_ids'][0].tolist()
- labels = data.get('labels', None)
- labels = labels[0].tolist() if labels is not None else None
- attn = data.get('attention_mask', None)
- attn = attn[0].tolist() if attn is not None else None
-
- # 1) decode input_ids(仅用来“看字”,不过滤/改动原序列)
- # 对于 SentencePiece/BPE,直接 decode 可能会清理空格导致肉眼对齐偏差
- # 因此建议同时看 tokens,以便逐 token 对齐核查
- toks = self.tokenizer.convert_ids_to_tokens([t if t >= 0 else self.tokenizer.pad_token_id for t in input_ids])
- print(f"\n--- TOKENS (len={len(toks)}) ---")
- print(toks)
-
- decoded_input = self.tokenizer.decode([t for t in input_ids if t >= 0],
- skip_special_tokens=False,
- clean_up_tokenization_spaces=False)
- print(f"\n--- DECODED INPUT (raw, keep specials) ---\n{decoded_input}")
-
- # 2) 处理 labels 的可视化
- if labels is not None:
- print(f"\n--- RAW LABELS (len={len(labels)}) ---\n{labels}")
-
- # -100: ignore_index;-200: 你自定义的图像占位
- # 为了解码+可视化,用 pad 替换 -100,并过滤掉 -200;但保留“序号对齐”所需的长度信息
- if self.tokenizer.pad_token_id is None:
- # 很多 LLaMA 类 tokenizer 无 pad;确保 pad 存在,避免 .decode 出错
- # 常见做法:把 eos 作为 pad(训练/调试阶段)
- print('no padding')
- self.tokenizer.pad_token = self.tokenizer.eos_token
-
- labels_for_decoding = [ (self.tokenizer.pad_token_id if t == -100 else t) for t in labels ]
- labels_text_only = [ t for t in labels_for_decoding if t != -200 and t >= 0 ]
-
- decoded_labels = self.tokenizer.decode(labels_text_only,
- skip_special_tokens=True,
- clean_up_tokenization_spaces=False)
- print(f"\n--- DECODED LABELS (text that contributes to loss) ---\n{decoded_labels}")
-
- # 3) 可视化“真正参加 loss 的 token 位置”(shift 后)
- # loss 对应 shift_labels=labels[:,1:]
- loss_pos = []
- for i in range(1, len(labels)):
- if labels[i] != -100: # shift_labels 位置
- loss_pos.append(i)
-
- print(f"\n--- #LOSS TOKENS (after shift) = {len(loss_pos)} ---")
- # 打印若干窗口,帮你核对每个 loss 位周边上下文
- window = 6
- rows = []
- for idx in loss_pos[:50]: # 只看前 50 个,避免刷屏
- left = max(0, idx - window)
- right = min(len(input_ids), idx + window + 1)
- seg_ids = input_ids[left:right]
- seg_toks = self.tokenizer.convert_ids_to_tokens(
- [t if t >= 0 else self.tokenizer.pad_token_id for t in seg_ids]
- )
- caret = " " * (idx - left) + "^"
- rows.append({
- "range": f"[{left}:{right})",
- "tokens": seg_toks,
- "label_id": labels[idx]
- })
- print(f"\nrange {left}:{right}")
- print(seg_toks)
- print(caret + " (this token's label participates in loss)")
- print(f"label_id={labels[idx]}")
- if not rows:
- print("\n(No tokens are contributing to loss; check label construction.)")
-
- # 4) 校验 attention_mask 与 labels 的一致性(非必须,但有用)
- if attn is not None and len(attn) == len(labels):
- bad = sum(1 for i in range(len(labels)) if labels[i] != -100 and attn[i] == 0)
- if bad > 0:
- print(f"\n[WARN] Found {bad} tokens with label != -100 but attention_mask == 0.")
- else:
- print("\n[INFO] No attention_mask provided or length mismatch; skipping attention checks.")
-
- print("\n" + "="*60 + "\n")
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- """
- 计算 token-level 交叉熵损失(分布式/AMP 兼容)。
- - labels 中 -100 为 ignore_index
- - 自动屏蔽负 ID(如 -200 图像占位)与 special_ids 对应位置
- """
-
- # 1) 若无 labels,退回 HF 默认
- if "labels" not in data:
- outputs = self.llm(**data)
- return {"loss": outputs.loss}
-
- labels = data["labels"] # [B, T]
- input_ids = data.get("input_ids", None) # [B, T] or None
- attn = data.get("attention_mask", None) # 可无
-
- # 2) 标签清洗(不改原 labels)
- safe_labels = labels.clone()
-
- # 2.1 屏蔽负 ID(如 -200 图像占位)
- if input_ids is not None:
- neg_mask = (input_ids < 0)
- if neg_mask.any():
- safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # 2.2 屏蔽 tokenizer 的特殊 token(模板标记等)
- if getattr(self, "tokenizer", None) is not None:
- try:
- special_ids = set(self.tokenizer.all_special_ids or [])
- except Exception:
- special_ids = set()
- if special_ids:
- special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
- for sid in special_ids:
- special_mask |= (input_ids == sid)
- if special_mask.any():
- safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # 3) 前向,拿 logits(不把 labels 交给 HF,避免其先做 per-device mean)
- model_inputs = {k: v for k, v in data.items() if k != "labels"}
- outputs = self.llm(**model_inputs, use_cache=False)
- logits = outputs.logits # [B, T, V]
-
- # 形状断言
- if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]:
- raise RuntimeError(
- f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}"
- )
-
- # 4) CausalLM 对齐
- shift_logits = logits[:, :-1, :].contiguous()
- shift_labels = safe_labels[:, 1:].contiguous()
-
- # 5) 统计有效 token & 分布式聚合
- n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
- world_size = 1
- n_tok_global = n_tok_local
- if dist.is_available() and dist.is_initialized():
- world_size = dist.get_world_size()
- with torch.no_grad():
- n_tok_global = n_tok_local.clone()
- dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- # 若全局无监督 token,则返回 0(防 NaN)
- if n_tok_global.item() == 0:
- zero = shift_logits.sum() * 0.0
- return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}
-
- # 6) 分子(sum over tokens,FP32 更稳)
- loss_sum_local = F.cross_entropy(
- shift_logits.float().view(-1, shift_logits.size(-1)),
- shift_labels.view(-1),
- ignore_index=-100,
- reduction="sum",
- )
-
- # 7) 全局 token 平均的 loss(抵消 DDP 的梯度平均)
- denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
- loss = (loss_sum_local / denom) * float(world_size)
-
- # 8) 返回
- ntok_tensor = denom.detach()
- return {"loss": loss, "ntok": ntok_tensor}
-
- # def compute_loss(self, data, data_samples=None):
- # outputs = self.llm(**data)
- # # outputs.logits.shape (1, 1094, 152064) for Qwen
- # loss_dict = {'loss': outputs.loss}
- # return loss_dict
- # === MOD: token-averaged, globally weighted loss (robust to variable lengths)
- # def compute_loss(self, data, data_samples=None):
- # """
- # 计算 token-level CE loss(分布式/AMP 兼容),在 loss 前做系统性检查与清洗:
- # - 将负 ID(如 -200 图像占位)/ special_id 对应位置统一置为 -100
- # - 可选的调试打印:每 global_step 仅打印一次,帮助确认首个受监督 token 是否对齐回答开头
- # 约定:
- # - labels 中 -100 为 ignore_index
- # - data 可能包含 input_ids / attention_mask
- # """
-
- # debug = bool(getattr(self, "debug_loss_checks", False))
-
- # # -----------------------------
- # # 1) 若无 labels,退回 HF 默认
- # # -----------------------------
- # if "labels" not in data:
- # outputs = self.llm(**data)
- # return {"loss": outputs.loss}
-
- # labels = data["labels"] # [B, T]
- # input_ids = data.get("input_ids", None) # [B, T] or None
- # attn = data.get("attention_mask", None) # [B, T] or None
-
- # # -----------------------------
- # # 2) loss 前清洗 + 轻量检查
- # # -----------------------------
- # # 2.1 pad_token 兜底(便于 decode 调试;不影响训练)
- # if getattr(self, "tokenizer", None) is not None and self.tokenizer.pad_token_id is None:
- # try:
- # self.tokenizer.pad_token = self.tokenizer.eos_token
- # except Exception:
- # pass
-
- # # 2.2 拷贝一份 labels,避免原地改动
- # safe_labels = labels.clone()
-
- # # 2.3 屏蔽负 ID(如 -200 图像占位)与 special_ids
- # if input_ids is not None:
- # # 负 ID
- # neg_mask = (input_ids < 0)
- # if neg_mask.any():
- # safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # # special ids(模板标记等)
- # if getattr(self, "tokenizer", None) is not None:
- # try:
- # special_ids = set(self.tokenizer.all_special_ids or [])
- # except Exception:
- # special_ids = set()
- # if special_ids:
- # special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
- # for sid in special_ids:
- # special_mask |= (input_ids == sid)
- # if special_mask.any():
- # safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # # 2.4 attention_mask 与 labels 的一致性(仅提示)
- # if attn is not None:
- # if attn.shape != labels.shape:
- # if debug:
- # print("[INFO] attention_mask shape != labels shape; skip attention checks.")
- # else:
- # bad = ((safe_labels != -100) & (attn == 0)).sum().item()
- # if debug and bad > 0:
- # print(f"[WARN] {bad} supervised tokens have attention_mask=0 (context invisible).")
-
- # # -----------------------------
- # # 3) 前向,拿 logits
- # # -----------------------------
- # model_inputs = {k: v for k, v in data.items() if k != "labels"}
- # outputs = self.llm(**model_inputs, use_cache=False)
- # logits = outputs.logits # [B, T, V]
-
- # # 形状断言
- # if logits.dim() != 3:
- # raise RuntimeError(f"logits should be [B,T,V], got {tuple(logits.shape)}")
- # if logits.shape[:2] != safe_labels.shape[:2]:
- # raise RuntimeError(
- # f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}"
- # )
-
- # B, T, V = logits.shape
-
- # # -----------------------------
- # # 4) CausalLM 对齐
- # # -----------------------------
- # shift_logits = logits[:, :-1, :].contiguous()
- # shift_labels = safe_labels[:, 1:].contiguous()
-
- # # -----------------------------
- # # 5) 统计有效 token & 分布式聚合
- # # -----------------------------
- # n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
- # world_size = 1
- # n_tok_global = n_tok_local
- # if dist.is_available() and dist.is_initialized():
- # world_size = dist.get_world_size()
- # with torch.no_grad():
- # n_tok_global = n_tok_local.clone()
- # dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- # # 若全局无监督 token,则返回 0 loss(防 NaN)
- # if n_tok_global.item() == 0:
- # if debug:
- # print("[INFO] No supervised tokens globally; returning zero loss.")
- # zero = shift_logits.sum() * 0.0 # 保持 dtype/device
- # return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}
-
- # # -----------------------------
- # # 6) 交叉熵分子(sum over tokens, FP32)
- # # -----------------------------
- # loss_sum_local = F.cross_entropy(
- # shift_logits.float().view(-1, V),
- # shift_labels.view(-1),
- # ignore_index=-100,
- # reduction="sum",
- # )
-
- # # -----------------------------
- # # 7) 归一化 & 数值守卫
- # # -----------------------------
- # denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
- # loss = (loss_sum_local / denom) * float(world_size)
-
- # if torch.isnan(loss) or torch.isinf(loss):
- # raise FloatingPointError(
- # f"Loss is NaN/Inf. loss_sum_local={loss_sum_local.item()}, n_tok_global={n_tok_global.item()}, world_size={world_size}"
- # )
-
- # # -----------------------------
- # # 8) (可选)每 step 打印一次的预检
- # # -----------------------------
- # if debug:
- # # 仅本 step 打印一次
- # if not hasattr(self, "_loss_precheck_printed_step"):
- # self._loss_precheck_printed_step = -1
- # cur_step = int(getattr(self, "global_step", 0))
- # if self._loss_precheck_printed_step != cur_step:
- # self._loss_precheck_printed_step = cur_step
-
- # # 首/末个受监督位置(after shift)
- # first_pos, last_pos, n_valid = [], [], []
- # for b in range(B):
- # mask = (shift_labels[b] != -100)
- # idxs = torch.nonzero(mask, as_tuple=False).flatten()
- # n_valid.append(int(mask.sum().item()))
- # first_pos.append(int(idxs[0].item()) if idxs.numel() else None)
- # last_pos.append(int(idxs[-1].item()) if idxs.numel() else None)
-
- # print(f"[LOSS-PRECHECK] B={B}, T={T}, V={V}")
- # for b in range(min(B, 2)): # 只打印前两个样本
- # print(f" sample {b}: n_valid={n_valid[b]}, first={first_pos[b]}, last={last_pos[b]}")
-
- # # 打印 sample0 上下文切片,确认 first 是否对齐回答开头
- # if input_ids is not None and getattr(self, "tokenizer", None) is not None:
- # fp = first_pos[0]
- # if fp is not None:
- # i = fp + 1 # shift 后位置 -> 原序列位置
- # L = max(0, i - 12)
- # R = min(T, i + 12)
- # pad_id = self.tokenizer.pad_token_id or getattr(self.tokenizer, "eos_token_id", None) or 0
- # toks = self.tokenizer.convert_ids_to_tokens(
- # [int(t if t >= 0 else pad_id) for t in input_ids[0, L:R].tolist()]
- # )
- # caret = " " * (i - L) + "^"
- # print(f" context[0] {L}:{R} tokens:\n {toks}\n {caret}")
- # try:
- # seg_ids = [int(x) for x in input_ids[0, L:R].tolist() if x >= 0]
- # seg = self.tokenizer.decode(seg_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
- # print(f" decoded segment:\n {seg}")
- # except Exception as e:
- # print(f" [decode-skip] {e}")
-
- # # 统计负 ID 屏蔽数
- # n_neg = int((input_ids[0] < 0).sum().item())
- # print(f" masked_by_negative_ids(sample0): {n_neg}")
-
- # # attention 额外统计
- # if attn is not None and attn.shape == labels.shape:
- # bad = int(((safe_labels != -100) & (attn == 0)).sum().item())
- # print(f" supervised_but_attn0(global): {bad}")
-
- # # -----------------------------
- # # 9) 返回
- # # -----------------------------
- # ntok_tensor = denom.detach() # float 标量张量即可
- # return {"loss": loss, "ntok": ntok_tensor}
-
- # def compute_loss(self, data, data_samples=None):
- # # 1) 若无 labels,退回 HF 默认
- # if 'labels' not in data:
- # outputs = self.llm(**data)
- # return {'loss': outputs.loss}
-
-
-
- # labels = data['labels'] # [B, T]
- # # 不把 labels 交给 HF,避免其先做 per-device mean
- # model_inputs = {k: v for k, v in data.items() if k != 'labels'}
-
- # outputs = self.llm(**model_inputs, use_cache=False)
- # logits = outputs.logits # [B, T, V]
-
- # # 2) CausalLM 对齐
- # shift_logits = logits[:, :-1, :].contiguous()
- # shift_labels = labels[:, 1:].contiguous()
-
- # # 3) 本卡有效 token 数(忽略 -100)
- # n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
-
- # # 4) 分子:sum over tokens(用 FP32 计算更稳)
- # loss_sum_local = F.cross_entropy(
- # shift_logits.float().view(-1, shift_logits.size(-1)),
- # shift_labels.view(-1),
- # ignore_index=-100,
- # reduction='sum'
- # )
-
- # # 5) 计算全局分母;不要让反传穿过 collective(用 no_grad + clone)
- # world_size = 1
- # n_tok_global = n_tok_local
- # if dist.is_available() and dist.is_initialized():
- # world_size = dist.get_world_size()
- # with torch.no_grad():
- # n_tok_global = n_tok_local.clone()
- # dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- # denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
-
- # # 6) 构造最终 loss:
- # # 用“本卡分子 / 全局分母”,再乘 world_size 抵消 DDP 的梯度平均,
- # # 这样反向后的等效梯度就是“全局 token 平均”的梯度。
- # loss = (loss_sum_local / denom) * float(world_size)
-
- # # 7) 记录指标:把 ntok 作为张量返回,避免 parse_losses 报错
- # ntok_tensor = denom.detach() # float 标量张量即可
-
- # return {
- # 'loss': loss,
- # 'ntok': ntok_tensor
- # }
-
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- # self.projector.save_pretrained(projector_path,
- # **save_pretrained_kwargs)
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
- if self.use_perceiver_resampler:
-
- perceiver_path = osp.join(save_dir, "perceiver")
- print_log(f'Saving LongNet_encoder to {perceiver_path}', 'current')
- os.makedirs(perceiver_path, exist_ok=True)
- perceiver_output_path = os.path.join(perceiver_path, 'perceiver.safetensors')
- save_file(self.perceiver.state_dict(), perceiver_output_path)
-
- # LongNet_encoder
- if self.LongNet_encoder is not None:
- LongNet_encoder_path = osp.join(save_dir, 'LongNet_encoder')
- print_log(f'Saving LongNet_encoder to {LongNet_encoder_path}', 'current')
- # Ensure the target directory exists
- os.makedirs(LongNet_encoder_path, exist_ok=True)
-
- # Define the full path for the weights file
- output_path = osp.join(LongNet_encoder_path, 'longnet_encoder.safetensors')
-
- # Save the state dictionary using safetensors
- save_file(self.LongNet_encoder.state_dict(), output_path)
-
-
-
-
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/llava_no_longnet.py b/code/xtuner/model/llava_no_longnet.py
deleted file mode 100644
index 799c7ed54a1c6f2b8d95c6ac7d0ee9fb1af08c4e..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_no_longnet.py
+++ /dev/null
@@ -1,1257 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-
-import torch
-import torch.nn as nn
-import torch.distributed as dist # === MOD ===
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-import os
-from safetensors.torch import load_file, save_file
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-import torch.nn.functional as F
-from .sparse_token_merge import SparsePatchMerging
-from xtuner.model.torchscale.model.pos_embed import get_2d_sincos_pos_embed
-from peft import PeftModel
-from peft.tuners.lora.layer import LoraLayer
-
-# ===== 在类前或类内其它位置都可以:新增一个探测函数 =====
-def _detect_qwen_major_version(llm) -> int:
- """
- 返回 3 表示 Qwen3,2 表示 Qwen2,0 表示未知/其它。
- 优先用 config.model_type,其次回退到类名字符串。
- """
- base = llm.model if hasattr(llm, "model") else llm
- cfg = getattr(base, "config", None)
- mt = (getattr(cfg, "model_type", None) or "").lower()
- if mt == "qwen3":
- return 3
- if mt == "qwen2":
- return 2
-
- # 回退:根据类名判别
- cname = base.__class__.__name__.lower()
- if "qwen3" in cname:
- return 3
- if "qwen2" in cname:
- return 2
- return 0
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
-
- # slide/pos-embed 参数
- slide_ngrids=1000,
- pe_gate_value=1.0,
- pe_dropout=0.1,
- tile_size=224,
-
- # 各子模块权重路径
- projector_pth=None,
- perceiver_pth=None,
- token_merge_pth=None,
- pe_gate_pth=None,
-
- # Token Merge
- enable_token_merge=True,
-
- # Perceiver Resampler 配置
- use_perceiver_resampler=True,
- concat_text_to_queries=True,
- perceiver_num_latents=64,
- perceiver_depth=2,
-
- # === 新增:Stage-2 冻结选项 ===
- freeze_mm_in_stage2=False, # 总开关:在 stage-2 冻结 projector / perceiver / token_merge
- freeze_projector_stage2=None, # 子开关(None 表示跟随总开关)
- freeze_perceiver_stage2=None, # 子开关(None 表示跟随总开关)
- freeze_token_merge_stage2=None # 子开关(None 表示跟随总开关)
- ):
- super().__init__()
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- self.tile_size = tile_size
-
- # 训练阶段控制
- if train_stage == '0':
- print_log('train_stage == 0', 'current')
- self.freeze_llm = True
- if train_stage == '1':
- print_log('train_stage == 1', 'current')
- self.freeze_llm = True
- elif train_stage == '2':
- print_log('train_stage == 2', 'current')
- self.freeze_llm = False
-
- # 解析 stage-2 的冻结意图
- def _resolve(flag):
- return freeze_mm_in_stage2 if flag is None else bool(flag)
- self._freeze_projector_in_s2 = _resolve(freeze_projector_stage2)
- self._freeze_perceiver_in_s2 = _resolve(freeze_perceiver_stage2)
- self._freeze_token_merge_in_s2 = _resolve(freeze_token_merge_stage2)
-
- # 构建 / 派发 LLM
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- # Token Merge
- self.enable_token_merge = enable_token_merge
- if self.enable_token_merge:
- self.token_merge = SparsePatchMerging(
- embed_dim=hidden_size,
- layernorm_eps=1e-6,
- merge_size=2
- )
-
- # Projector
- self.projector_depth = projector_depth
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size * 4 if self.enable_token_merge else hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth
- )
- self.projector = ProjectorModel(projector_config).to(self.llm.dtype)
- self.projector.requires_grad_(True)
-
- # Perceiver Resampler
- self.use_perceiver_resampler = use_perceiver_resampler
- self.slide_ngrids = slide_ngrids
- if self.use_perceiver_resampler:
- self.perceiver_num_latents = perceiver_num_latents
- self.perceiver_depth = perceiver_depth
-
- num_patches = slide_ngrids ** 2
- self.pe_gate = nn.Parameter(torch.tensor(pe_gate_value, dtype=self.llm.dtype))
- self.pe_drop = nn.Dropout(pe_dropout)
- self.register_buffer(
- 'pos_embed',
- torch.zeros(1, num_patches, self.llm.config.hidden_size),
- persistent=False
- )
-
- # 自动选择 Qwen2 / Qwen3 的 Perceiver 实现
- qwen_major = _detect_qwen_major_version(self.llm)
- print_log(f'using qwen version {qwen_major}', 'current')
- if qwen_major == 3:
- try:
- from .qwen3_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
- print_log('using qwen3', 'current')
- except Exception as e:
- raise RuntimeError(
- "检测到 Qwen3,但未找到 qwen3_perceiver_resampler,请确认文件存在且 transformers 版本满足要求(>=4.51)。"
- ) from e
- elif qwen_major == 2:
- from .qwen2_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
- else:
- warnings.warn(
- "未能确定 Qwen 主版本(既不是 qwen3 也不是 qwen2)。将回退到 Qwen2 的 Perceiver 实现。",
- RuntimeWarning,
- )
- from .qwen2_perceiver_resampler import (
- PerceiverResampler as _PR,
- init_perceiver_from_llm_auto as _init_pr,
- )
-
- if concat_text_to_queries:
- print_log("concat text to queries in perceiver", 'current')
-
- self.perceiver = _PR(
- self.llm,
- num_latents=self.perceiver_num_latents,
- depth=self.perceiver_depth,
- concat_text_to_queries=concat_text_to_queries,
- ).to(self.llm.dtype)
-
- # 仅当没有提供 perceiver_pth 或路径不存在时,才尝试从 LLM 自动初始化
- if perceiver_pth is None or not os.path.exists(perceiver_pth):
- _init_pr(
- perceiver=self.perceiver,
- llm=self.llm,
- ckpt_hint=getattr(self.llm.config, "_name_or_path", None),
- init_from_layers=self.perceiver.depth,
- layer_offset=0,
- allow_download=False,
- )
-
- # 初始化 pos-embed 等
- self.initialize_pe_weights()
-
- # 冻结 LLM
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- # 激活检查点(按需对冻结模块跳过 input-grad 使能)
- if use_activation_checkpointing:
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- if self.use_perceiver_resampler:
- _perceiver_frozen = (train_stage == '2' and self._freeze_perceiver_in_s2)
- if not _perceiver_frozen:
- self.perceiver.enable_input_require_grads()
- else:
- print_log('[stage-2] Skipping perceiver.enable_input_require_grads() (frozen)', 'current')
-
- _projector_frozen = (train_stage == '2' and self._freeze_projector_in_s2)
- if not _projector_frozen:
- print('enable projector input require grads')
- print_log('enable projector input require grads', 'current')
- self.projector.enable_input_require_grads()
- else:
- print_log('[stage-2] Skipping projector.enable_input_require_grads() (frozen)', 'current')
-
- # 启用激活检查点
- self.gradient_checkpointing_enable()
-
- # LoRA
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = None
- if self.use_llm_lora:
- print_log(f"Building lora {llm_lora.__str__}", "current")
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
- self.verify_lora()
-
- # 加载 token_merge / projector / perceiver / pe_gate 的 safetensors
- if token_merge_pth is not None and enable_token_merge and hasattr(self, 'token_merge'):
- print_log(f'loading token_merge from {token_merge_pth}', 'current')
- merger_sd = load_file(token_merge_pth, device='cpu')
- self.token_merge.load_state_dict(merger_sd, strict=False)
- self.token_merge.to(self.llm.dtype)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- if perceiver_pth is not None and self.use_perceiver_resampler and hasattr(self, 'perceiver'):
- print_log(f'Loading perceiver from {perceiver_pth}', 'current')
- perceiver_sd = load_file(perceiver_pth, device="cpu")
- self.perceiver.load_state_dict(perceiver_sd, strict=False)
- self.perceiver.to(self.llm.dtype)
-
- if pe_gate_pth is not None and self.use_perceiver_resampler and hasattr(self, "pe_gate"):
- print_log(f'Loading pe_gate from {pe_gate_pth}', 'current')
- sd = load_file(pe_gate_pth, device="cpu")
- if "pe_gate" not in sd:
- raise KeyError(f"'pe_gate' not found in {pe_gate_pth}. Keys: {list(sd.keys())}")
- with torch.no_grad():
- self.pe_gate.copy_(sd["pe_gate"].to(dtype=self.llm.dtype, device=self.pe_gate.device))
-
- # 额外加载 float 权重(可选)
- if pretrained_pth is not None:
- sd = guess_load_checkpoint(pretrained_pth)
- model_sd = self.state_dict()
- filtered = {k: v for k, v in sd.items() if k in model_sd and model_sd[k].shape == v.shape}
- missing, unexpected = self.load_state_dict(filtered, strict=False)
- print_log(f"Loaded float ckpt from {pretrained_pth}", "current")
- print_log(f" missing: {missing}", "current")
- print_log(f" unexpected:{unexpected}", "current")
-
- # 记录可视层
- self.visual_select_layer = visual_select_layer
-
- # 初始化标志
- self._is_init = True
- self.is_first_iter = True
-
- # === 关键新增:在 Stage-2 按需冻结三个多模态子模块 ===
- if train_stage == '2':
- # projector
- if hasattr(self, 'projector') and self._freeze_projector_in_s2:
- self.projector.requires_grad_(False)
- self.projector.eval()
- print_log('[stage-2] Freezing projector parameters', 'current')
-
- # perceiver(含 pe_gate)
- if getattr(self, 'use_perceiver_resampler', False) and hasattr(self, 'perceiver') and self._freeze_perceiver_in_s2:
- self.perceiver.requires_grad_(False)
- self.perceiver.eval()
- print_log('[stage-2] Freezing perceiver parameters', 'current')
- if hasattr(self, 'pe_gate') and self._freeze_perceiver_in_s2:
- self.pe_gate.requires_grad = False
-
- # token_merge
- if getattr(self, 'enable_token_merge', False) and hasattr(self, 'token_merge') and self._freeze_token_merge_in_s2:
- self.token_merge.requires_grad_(False)
- self.token_merge.eval()
- print_log('[stage-2] Freezing token_merge parameters', 'current')
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- # def initialize_pe_weights(self):
- # # initialization
- # # initialize (and freeze) pos_embed by sin-cos embedding
- # if self.use_perceiver_resampler:
- # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.slide_ngrids, cls_token=False)
- # self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
- @torch.no_grad()
- def initialize_pe_weights(self, chunk_rows: int = 64, chunk_cols: int = 64):
- """
- 在 GPU 上用 float64 精度生成 2D sin-cos 位置编码,
- 逻辑与 numpy 版本完全一致,然后写入 self.pos_embed。
- """
- if not getattr(self, "use_perceiver_resampler", False):
- return
-
- H = W = int(self.slide_ngrids)
- D = int(self.llm.config.hidden_size)
- assert D % 4 == 0, "hidden_size 必须是 4 的倍数,才能和 numpy 实现严格对应。"
-
- device = self.pos_embed.device
- dtype64 = torch.float64 # 全程用 float64
-
- # 预分配/调整 buffer 形状
- if self.pos_embed.shape != (1, H * W, D):
- self.pos_embed.resize_(1, H * W, D)
-
- pos4d = self.pos_embed.view(1, H, W, D)
-
- # 频率向量
- k = D // 4
- inv = 1.0 / (10000 ** (torch.arange(k, device=device, dtype=dtype64) / k))
-
- # 整数坐标 (与 numpy 一致)
- y_lin = torch.arange(H, device=device, dtype=dtype64)
- x_lin = torch.arange(W, device=device, dtype=dtype64)
-
- # 一维编码
- y_phase = y_lin.unsqueeze(1) * inv.unsqueeze(0) # [H,k]
- x_phase = x_lin.unsqueeze(1) * inv.unsqueeze(0) # [W,k]
- y_enc = torch.cat([torch.sin(y_phase), torch.cos(y_phase)], dim=1) # [H,2k]
- x_enc = torch.cat([torch.sin(x_phase), torch.cos(x_phase)], dim=1) # [W,2k]
-
- # 分块写入,避免一次性大张量
- for r0 in range(0, H, chunk_rows):
- r1 = min(r0 + chunk_rows, H)
- R = r1 - r0
- y_chunk = y_enc[r0:r1].unsqueeze(1) # [R,1,2k]
-
- for c0 in range(0, W, chunk_cols):
- c1 = min(c0 + chunk_cols, W)
- C = c1 - c0
- x_chunk = x_enc[c0:c1].unsqueeze(0) # [1,C,2k]
-
- # 拼接顺序与 numpy 一致: [emb_w, emb_h]
- emb_rc = torch.cat(
- [x_chunk.expand(R, C, 2*k),
- y_chunk.expand(R, C, 2*k)],
- dim=2
- ) # [R,C,D]
-
- # copy 到 buffer(自动 cast 到 buffer dtype)
- pos4d[0, r0:r1, c0:c1, :].copy_(emb_rc.to(pos4d.dtype))
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- # we use xavier_uniform following official JAX ViT:
- torch.nn.init.xavier_uniform_(m.weight)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def verify_lora(self):
- m = self.llm
-
- # 1) Wrapped as a PEFT model
- assert isinstance(m, PeftModel), "LoRA not applied: model is not a PeftModel"
-
- # 2) Adapters are registered and active
- adapters = m.peft_config # dict: {adapter_name: LoraConfig}
- assert len(adapters) > 0, "No adapters registered in peft_config"
- active = m.active_adapter if hasattr(m, "active_adapter") else None
- assert active in adapters, f"Active adapter {active} not found in peft_config"
-
- # 3) LoRA layers are present on target modules
- lora_modules = [mod for mod in m.modules() if isinstance(mod, LoraLayer)]
- assert len(lora_modules) > 0, "No LoraLayer modules found (check target_modules)"
-
- # 4) LoRA params are the only trainable ones (typical for QLoRA)
- trainable = [(n,p) for n,p in m.named_parameters() if p.requires_grad]
- assert len(trainable) > 0, "No trainable parameters (LoRA params are not set to requires_grad=True)"
- # Optional: sanity-check that trainable params look like LoRA
- suspicious = [n for n,_ in trainable if "lora_" not in n and "modules_to_save" not in n]
- # It's okay if you intentionally left some modules_to_save; adjust as needed.
- assert len(suspicious) == 0, f"Unexpected trainable params (not LoRA): {suspicious[:5]}"
-
- # 5) Quick count + readable log
- total = sum(p.numel() for _,p in m.named_parameters())
- trainable_cnt = sum(p.numel() for _,p in trainable)
- ratio = trainable_cnt / total
- print(f"[LoRA OK] adapters={list(adapters.keys())}, active={active}, "
- f"LoraLayers={len(lora_modules)}, trainable={trainable_cnt}/{total} ({ratio:.4%})")
-
- # 6) Forward+backward smoke test to confirm gradients flow to LoRA only
- m.train()
- dummy_inp = torch.randint(0, m.get_input_embeddings().num_embeddings, (1, 8)).to(next(m.parameters()).device)
- out = m(input_ids=dummy_inp, labels=dummy_inp)
- out.loss.backward() # should not error
- # Ensure some LoRA grads exist
- lora_grads = [p.grad for _,p in m.named_parameters() if p.requires_grad and p.grad is not None]
- assert len(lora_grads) > 0, "No gradients on LoRA parameters after backward()"
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self, use_reentrant=False):
- self.activation_checkpointing_enable(use_reentrant=use_reentrant)
-
- def activation_checkpointing_enable(self, use_reentrant=False):
- # LLM
- try:
- self.llm.gradient_checkpointing_enable(use_reentrant=use_reentrant)
- except TypeError:
- # older HF versions
- self.llm.gradient_checkpointing_enable()
-
- # projector
- try:
- self.projector.gradient_checkpointing_enable(use_reentrant=use_reentrant)
- except TypeError:
- self.projector.gradient_checkpointing_enable()
-
- # perceiver (if present)
- if getattr(self, 'use_perceiver_resampler', False) and getattr(self, 'perceiver', None) is not None:
- try:
- self.perceiver.gradient_checkpointing_enable(use_reentrant=use_reentrant)
- except AttributeError:
- # some custom modules only expose input-grad helper
- if hasattr(self.perceiver, 'enable_input_require_grads'):
- self.perceiver.enable_input_require_grads()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- if self.use_perceiver_resampler:
- self.perceiver.disable_gradient_checkpointing()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
-
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 5. Perceiver Resampler (unchanged)
- if getattr(self, 'use_perceiver_resampler', False) and getattr(self, 'perceiver', None) is not None:
- to_return.update({k: v for k, v in state_dict.items() if 'perceiver.' in k})
-
- if getattr(self, 'pe_gate', False):
- to_return.update({k: v for k, v in state_dict.items() if 'pe_gate.' in k})
-
- # step 5 token merger
- if getattr(self, 'token_merge', False):
- to_return.update({k: v for k, v in state_dict.items() if 'token_merge.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
-
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def coords_to_pos(self, coords, tile_size: int = 224):
- """
- This function is used to convert the coordinates to the positional indices
-
- Arguments:
- ----------
- coords: torch.Tensor
- The coordinates of the patches, of shape [N, L, 2]
- output: torch.Tensor
- The positional indices of the patches, of shape [N, L]
- """
- coords_ = torch.floor(coords / tile_size)
- pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
- return pos.long() # add 1 for the cls token
-
- @staticmethod
- def _coords_rc_to_pos(coords_rc: torch.Tensor, ngrids: int) -> torch.Tensor:
- if coords_rc.dtype.is_floating_point:
- coords_rc = coords_rc.round().to(torch.long)
- # row = coords_rc[:, 0].clamp_(0, ngrids-1)
- # col = coords_rc[:, 1].clamp_(0, ngrids-1)
- return (coords_rc[..., 0] * ngrids + coords_rc[..., 1]).long() # +1 for cls
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
- coords = None
-
- if 'pixel_values' in data:
-
- feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
-
- if 'coords' in data:
- coords = data['coords'].to(self.llm.dtype)
- # Accept: list[tensor], [L,2] tensor, or [B,L,2] tensor
- coords_t = coords[0] if isinstance(coords, list) else coords
- Bx = feat_to_proj.size(0) # actual batch size of inputs
- if not torch.is_tensor(coords_t):
- raise ValueError("coords must be a Tensor or list[Tensor].")
-
- if coords_t.dim() == 2:
- # [L, 2]
- coords_rc = coords_t
- elif coords_t.dim() == 3:
- # [B, L, 2] -> ensure B matches and either B==1 or all examples share coords
- if coords_t.size(0) != Bx:
- raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}")
- if Bx == 1:
- coords_rc = coords_t[0]
- else:
- # require same coords across the batch (cheap equality check)
- if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)):
- raise NotImplementedError(
- "Per-example coords (varying across batch) are not supported by the current "
- "patch-merging/layout path. Use batch size 1 or share coords across the batch."
- )
- coords_rc = coords_t[0]
- else:
- raise ValueError("coords must have shape [L,2] or [B,L,2].")
-
- if coords_rc.size(-1) != 2:
- raise ValueError("coords last dimension must be 2.")
- else:
- raise RuntimeError
-
- # only works for batch size one
- if self.enable_token_merge:
- feat_to_proj, coords_rc_merged, _ = self.token_merge(
- x=feat_to_proj,
- coords_rc=self._coords_to_rowcol(coords_rc), # 你已有,生成 rc
- padmask=torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)],
- device=feat_to_proj.device, dtype=torch.bool)
- )
- else:
- coords_rc_merged = self._coords_to_rowcol(coords_rc)
- padmask_merged = torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)],
- device=feat_to_proj.device, dtype=torch.bool)
-
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, 3584]
-
- if self.use_perceiver_resampler and 'input_ids' in data:
- text_emb = self.llm.get_input_embeddings()(data["input_ids"].clamp(min=0)) \
- .to(self.llm.dtype).detach()
- # 注意:这里的 coords_RC 已经是合并后的 (row, col)
- # print(coords_rc_merged.max(), coords_rc_merged.shape)
- pos = self._coords_rc_to_pos(coords_rc_merged, self.slide_ngrids) # B==1 假设
- # print(pos.max(), pos.shape)
- pixel_values = pixel_values + self.pe_drop(self.pos_embed[:, pos, :].squeeze(0) * self.pe_gate)
-
- compressed = self.perceiver(
- # input_ids = data["input_ids"],
- text_embeddings=text_emb,
- attention_mask=data.get("attention_mask", None),
- visual_tokens=pixel_values,
- )
- data["pixel_values"] = compressed
- else:
- data['pixel_values'] = pixel_values # shape: [1, patch_num, 3584] # shape: [1, 576, 4096]
-
- # remove coords
- data.pop('coords', None)
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- @staticmethod
- def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
- with torch.no_grad():
- x = coords_xy[:, 0]
- y = coords_xy[:, 1]
- x_for_unique = x
- y_for_unique = y
- if x_for_unique.dtype.is_floating_point:
- x_for_unique = x_for_unique.round().to(torch.int)
- y_for_unique = y_for_unique.round().to(torch.int)
- x_sorted = torch.unique(x_for_unique, sorted=True)
- y_sorted = torch.unique(y_for_unique, sorted = True)
-
- col = torch.searchsorted(x_sorted, x)
- row = torch.searchsorted(y_sorted, y)
- return torch.stack([row, col], dim=-1)
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- # def compute_loss(self, data, data_samples=None):
- # outputs = self.llm(**data)
- # # outputs.logits.shape (1, 1094, 152064) for Qwen
- # loss_dict = {'loss': outputs.loss}
- # return loss_dict
-
-
- # # === MOD: token-averaged, globally weighted loss (robust to variable lengths)
- def compute_loss(self, data, data_samples=None):
- """
- 计算 token-level 交叉熵损失(分布式/AMP 兼容)。
- - labels 中 -100 为 ignore_index
- - 自动屏蔽负 ID(如 -200 图像占位)与 special_ids 对应位置
- """
-
- # 1) 若无 labels,退回 HF 默认
- if "labels" not in data:
- outputs = self.llm(**data)
- return {"loss": outputs.loss}
-
- labels = data["labels"] # [B, T]
- input_ids = data.get("input_ids", None) # [B, T] or None
- attn = data.get("attention_mask", None) # 可无
-
- # 2) 标签清洗(不改原 labels)
- safe_labels = labels.clone()
-
- # 2.1 屏蔽负 ID(如 -200 图像占位)
- if input_ids is not None:
- neg_mask = (input_ids < 0)
- if neg_mask.any():
- safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # 2.2 屏蔽 tokenizer 的特殊 token(模板标记等)
- if getattr(self, "tokenizer", None) is not None:
- try:
- special_ids = set(self.tokenizer.all_special_ids or [])
- except Exception:
- special_ids = set()
- if special_ids:
- special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
- for sid in special_ids:
- special_mask |= (input_ids == sid)
- if special_mask.any():
- safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # 3) 前向,拿 logits(不把 labels 交给 HF,避免其先做 per-device mean)
- model_inputs = {k: v for k, v in data.items() if k != "labels"}
- outputs = self.llm(**model_inputs, use_cache=False)
- logits = outputs.logits # [B, T, V]
-
- # 形状断言
- if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]:
- raise RuntimeError(
- f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}"
- )
-
- # 4) CausalLM 对齐
- shift_logits = logits[:, :-1, :].contiguous()
- shift_labels = safe_labels[:, 1:].contiguous()
-
- # 5) 统计有效 token & 分布式聚合
- n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
- world_size = 1
- n_tok_global = n_tok_local
- if dist.is_available() and dist.is_initialized():
- world_size = dist.get_world_size()
- with torch.no_grad():
- n_tok_global = n_tok_local.clone()
- dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- # 若全局无监督 token,则返回 0(防 NaN)
- if n_tok_global.item() == 0:
- zero = shift_logits.sum() * 0.0
- return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}
-
- # 6) 分子(sum over tokens,FP32 更稳)
- loss_sum_local = F.cross_entropy(
- shift_logits.float().view(-1, shift_logits.size(-1)),
- shift_labels.view(-1),
- ignore_index=-100,
- reduction="sum",
- )
-
- # 7) 全局 token 平均的 loss(抵消 DDP 的梯度平均)
- denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
- loss = (loss_sum_local / denom) * float(world_size)
-
- # 8) 返回
- ntok_tensor = denom.detach()
- return {"loss": loss, "ntok": ntok_tensor}
-
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- # self.projector.save_pretrained(projector_path,
- # **save_pretrained_kwargs)
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
- if self.use_perceiver_resampler and hasattr(self, 'perceiver'):
-
- perceiver_path = osp.join(save_dir, "perceiver")
- print_log(f'Saving LongNet_encoder to {perceiver_path}', 'current')
- os.makedirs(perceiver_path, exist_ok=True)
- perceiver_output_path = os.path.join(perceiver_path, 'perceiver.safetensors')
- save_file(self.perceiver.state_dict(), perceiver_output_path)
-
- if self.enable_token_merge and hasattr(self, 'token_merge'):
- merger_path = osp.join(save_dir, 'token_merger')
- print_log(f'Saving token merger to{merger_path}', 'current')
- os.makedirs(merger_path, exist_ok= True)
- merger_path = osp.join(merger_path, 'merger.safetensors')
- save_file(self.token_merge.state_dict(), merger_path)
-
- if self.use_perceiver_resampler and hasattr(self, 'pe_gate'):
- pe_gate_path = osp.join(save_dir, 'pe_gate')
- print_log(f'saving pe_gate to {pe_gate_path}', 'current')
- os.makedirs(pe_gate_path, exist_ok= True)
- pe_gate_output_path = os.path.join(pe_gate_path, 'pe_gate.safetensors')
- # choose dtype for saving
- save_dtype = torch.float32 if fp32 else self.llm.dtype
- # save as a single-tensor safetensors file
- save_file(
- {"pe_gate": self.pe_gate.detach().to(save_dtype).cpu()},
- pe_gate_output_path
- )
-
-
-
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/llava_no_longnet_simple_sampler.py b/code/xtuner/model/llava_no_longnet_simple_sampler.py
deleted file mode 100644
index 1b283cc1aded045ab4c5fc9f368e4494ddcc7e66..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_no_longnet_simple_sampler.py
+++ /dev/null
@@ -1,1249 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os
-import os.path as osp
-import warnings
-from collections import OrderedDict
-from functools import partial
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
-from peft.tuners.lora.layer import LoraLayer
-from safetensors.torch import load_file, save_file
-from torch.nn.init import trunc_normal_
-from torch.utils.checkpoint import checkpoint
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.model.torchscale.component.multihead_attention import MultiheadAttention
-from xtuner.model.torchscale.architecture.config import EncoderConfig
-
-from xtuner.model.torchscale.model.pos_embed import get_2d_sincos_pos_embed
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .sparse_token_merge import SparsePatchMerging
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-
-# --- 辅助函数 (来自您的代码,保持不变) ---
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h)
- grid = np.stack(grid, axis=0)
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
- emb = np.concatenate([emb_h, emb_w], axis=1)
- return emb
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=np.float32)
- omega /= embed_dim / 2.
- omega = 1. / 10000**omega
- pos = pos.reshape(-1)
- out = np.einsum('m,d->md', pos, omega)
- emb_sin = np.sin(out)
- emb_cos = np.cos(out)
- emb = np.concatenate([emb_sin, emb_cos], axis=1)
- return emb
-
-# --- 修正后的 Resampler 类 ---
-class Resampler(nn.Module):
- """
- 修正后的 Resampler 版本:
- 1. 区分 query_pos_embed 和 input_pos_embed,解决变量冲突。
- 2. 解除对外部 llm 模块的依赖,提高封装性。
- 3. 修正 forward 方法中的位置编码应用逻辑和维度匹配。
- 4. 集成梯度检查点(gradient_checkpointing)功能以节省显存。
- """
- def __init__(
- self,
- grid_size,
- embed_dim,
- num_heads,
- slide_ngrids=1000, # 从外部传入网格大小
- kv_dim=None,
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
- gradient_checkpointing=False # 控制是否启用梯度检查点
- ):
- super().__init__()
- self.num_queries = grid_size ** 2
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.slide_ngrids = slide_ngrids
- self.gradient_checkpointing = gradient_checkpointing
-
- # 1. 用于 Query 的位置编码 (固定,不参与训练)
- self.query_pos_embed = nn.Parameter(
- torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float(),
- requires_grad=False
- )
-
- # 2. 用于输入视觉特征的位置编码 (大 buffer,在 GPU 上生成)
- num_patches = slide_ngrids ** 2
- self.register_buffer(
- 'input_pos_embed',
- torch.zeros(1, num_patches, embed_dim),
- persistent=False
- )
-
- # 可学习的 Query 向量
- self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
- trunc_normal_(self.query, std=.02)
-
- # KV 投影层
- if kv_dim is not None and kv_dim != embed_dim:
- self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
- else:
- self.kv_proj = nn.Identity()
-
- # 核心模块
- self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
- # args = EncoderConfig()
- # self.attn = MultiheadAttention(args =args,
- # embed_dim= embed_dim,
- # num_heads=num_heads,
- # self_attention=False,
- # encoder_decoder_attention=True,
- # )
-
- self.ln_q = norm_layer(embed_dim)
- self.ln_kv = norm_layer(embed_dim)
- self.ln_post = norm_layer(embed_dim)
- self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
-
- # 初始化权重和输入位置编码
- self.apply(self._init_weights)
- self.initialize_input_pe_weights()
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.no_grad()
- def initialize_input_pe_weights(self, chunk_rows: int = 64, chunk_cols: int = 64):
- H = W = self.slide_ngrids
- D = self.embed_dim
- assert D % 4 == 0, "embed_dim 必须是 4 的倍数,才能和 numpy 实现严格对应。"
-
- device = self.input_pos_embed.device
- dtype64 = torch.float64
-
- if self.input_pos_embed.shape != (1, H * W, D):
- self.input_pos_embed.resize_(1, H * W, D)
-
- pos4d = self.input_pos_embed.view(1, H, W, D)
-
- k = D // 4
- inv = 1.0 / (10000 ** (torch.arange(k, device=device, dtype=dtype64) / k))
-
- y_lin = torch.arange(H, device=device, dtype=dtype64)
- x_lin = torch.arange(W, device=device, dtype=dtype64)
-
- y_phase = y_lin.unsqueeze(1) * inv.unsqueeze(0)
- x_phase = x_lin.unsqueeze(1) * inv.unsqueeze(0)
- y_enc = torch.cat([torch.sin(y_phase), torch.cos(y_phase)], dim=1)
- x_enc = torch.cat([torch.sin(x_phase), torch.cos(x_phase)], dim=1)
-
- for r0 in range(0, H, chunk_rows):
- r1 = min(r0 + chunk_rows, H)
- R = r1 - r0
- y_chunk = y_enc[r0:r1].unsqueeze(1)
-
- for c0 in range(0, W, chunk_cols):
- c1 = min(c0 + chunk_cols, W)
- C = c1 - c0
- x_chunk = x_enc[c0:c1].unsqueeze(0)
- emb_rc = torch.cat([
- x_chunk.expand(R, C, 2 * k),
- y_chunk.expand(R, C, 2 * k)
- ], dim=2)
- pos4d[0, r0:r1, c0:c1, :].copy_(emb_rc.to(pos4d.dtype))
-
- def _checkpointed_forward(self, q_embed, kv_embed):
- # 封装 attention 和后续层,用于梯度检查点
- # q_embed: [num_queries, N, C], kv_embed: [L, N, C]
- # print(f"_checkpointed_forward q_embed shape: {q_embed.shape}, kv_embed shape: {kv_embed.shape}")
- attn_out = self.attn(q_embed, kv_embed, kv_embed)[0]
- permuted_out = attn_out
- ln_out = self.ln_post(permuted_out)
- proj_out = ln_out @ self.proj
- return proj_out
-
- def forward(self, x, coords_rc, attn_mask=None):
- # x shape: [N, L, C], coords_rc: [L, 2] (row, col indices)
-
- # 1. 从 buffer 中根据坐标索引,为输入 tokens 获取位置编码
- # .squeeze(0) 移除批次维度,然后进行索引
- # print(f"Resampler input x shape: {x.shape}, coords_rc shape: {coords_rc.shape}")
- pos_indices = (coords_rc[..., 0] * self.slide_ngrids + coords_rc[..., 1]).long()
- # print(f"Resampler input pos_indices shape: {pos_indices.shape}, values: {pos_indices}")
- input_pos = self.input_pos_embed[:, pos_indices, :].squeeze(0) # Shape: [L, C]
- # print(f"Resampler input_pos shape: {input_pos.shape}")
-
- # [MODIFIED] 直接在 (N, L, C) 格式上操作,不再需要 permute
- x = self.kv_proj(x)
- kv_embed = self.ln_kv(x)
-
- N = x.shape[0]
- q = self.ln_q(self.query) # Shape: [num_queries, C]
-
- # [MODIFIED] 调整维度扩展方式以适应 batch-first
- # 将 query 从 [num_queries, C] 扩展到 [N, num_queries, C]
- q_embed = q.unsqueeze(0).expand(N, -1, -1) + self.query_pos_embed.unsqueeze(0)
-
- # [MODIFIED] 将 input_pos 从 [L, C] 扩展到 [1, L, C] 以便与 kv_embed [N, L, C] 相加
- kv_embed = kv_embed + input_pos
-
- if self.training and self.gradient_checkpointing:
- q_embed.requires_grad_(True)
- kv_embed.requires_grad_(True)
- out = checkpoint(self._checkpointed_forward, q_embed, kv_embed, use_reentrant=False)
- else:
- out = self._checkpointed_forward(q_embed, kv_embed)
-
- return out
-
- def enable_input_require_grads(self):
- print_log("enable input required grads for projector", 'current')
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- self.model.register_forward_hook(make_inputs_require_grad)
-
- def gradient_checkpointing_enable(self):
- self.gradient_checkpointing = True
-
- def gradient_checkpointing_disable(self):
- self.gradient_checkpointing = False
-
- def _repeat(self, query, N: int):
- return query.unsqueeze(1).repeat(1, N, 1)
-# =================================================================================================
-# End of Resampler code
-# =================================================================================================
-
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- def __init__(self, output_size):
- super(AdaptiveAvgPool1dLayer, self).__init__()
- self.output_size = output_size
-
- def forward(self, x):
- return F.adaptive_avg_pool1d(x, self.output_size)
-
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
-
- # slide/pos-embed 参数
- slide_ngrids=1000,
- tile_size=224,
-
- # 各子模块权重路径
- projector_pth=None,
- resampler_pth=None,
- token_merge_pth=None,
-
- # Token Merge
- enable_token_merge=True,
-
- # Resampler 配置
- use_resampler=True,
- resampler_num_latents=256,
- resampler_heads = 16,
-
- # === 新增:Stage-2 冻结选项 ===
- freeze_mm_in_stage2=False, # 总开关:在 stage-2 冻结 projector / resampler / token_merge
- freeze_projector_stage2=None, # 子开关(None 表示跟随总开关)
- freeze_resampler_stage2=None, # 子开关(None 表示跟随总开关)
- freeze_token_merge_stage2=None # 子开关(None 表示跟随总开关)
- ):
- super().__init__()
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- self.tile_size = tile_size
-
- # 训练阶段控制
- if train_stage == '0':
- print_log('train_stage == 0', 'current')
- self.freeze_llm = True
- if train_stage == '1':
- print_log('train_stage == 1', 'current')
- self.freeze_llm = True
- elif train_stage == '2':
- print_log('train_stage == 2', 'current')
- self.freeze_llm = False
-
- # 解析 stage-2 的冻结意图
- def _resolve(flag):
- return freeze_mm_in_stage2 if flag is None else bool(flag)
- self._freeze_projector_in_s2 = _resolve(freeze_projector_stage2)
- self._freeze_resampler_in_s2 = _resolve(freeze_resampler_stage2)
- self._freeze_token_merge_in_s2 = _resolve(freeze_token_merge_stage2)
-
- # 构建 / 派发 LLM
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- # Token Merge
- self.enable_token_merge = enable_token_merge
- if self.enable_token_merge:
- self.token_merge = SparsePatchMerging(
- embed_dim=hidden_size,
- layernorm_eps=1e-6,
- merge_size=2
- )
-
- # Projector
- self.projector_depth = projector_depth
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size * 4 if self.enable_token_merge else hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth
- )
- self.projector = ProjectorModel(projector_config).to(self.llm.dtype)
- self.projector.requires_grad_(True)
-
- # Resampler
- self.use_resampler = use_resampler
- self.slide_ngrids = slide_ngrids
- if self.use_resampler:
- self.resampler_num_latents = resampler_num_latents
- print_log(f'using simple Resampler with {resampler_num_latents} latents', 'current')
- self.resampler = Resampler(
- grid_size=int(math.sqrt(self.resampler_num_latents)),
- embed_dim=self.llm.config.hidden_size,
- num_heads=resampler_heads,
- kv_dim=self.llm.config.hidden_size,
- ).to(self.llm.dtype)
-
-
- # 冻结 LLM
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- # 激活检查点(按需对冻结模块跳过 input-grad 使能)
- if use_activation_checkpointing:
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- # Resampler is a simple nn.Module and does not have this method.
- # If checkpointing is desired for it, its forward pass should be wrapped.
- # For this modification, we will omit its specific checkpointing setup.
-
- _projector_frozen = (train_stage == '2' and self._freeze_projector_in_s2)
- if not _projector_frozen:
- print('enable projector input require grads')
- print_log('enable projector input require grads', 'current')
- self.projector.enable_input_require_grads()
- else:
- print_log('[stage-2] Skipping projector.enable_input_require_grads() (frozen)', 'current')
-
- # 启用激活检查点
- self.gradient_checkpointing_enable()
-
- # LoRA
- self.use_llm_lora = llm_lora is not None
- self.use_visual_encoder_lora = None
- if self.use_llm_lora:
- print_log(f"Building lora {llm_lora.__str__}", "current")
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
- self.verify_lora()
-
- # 加载 token_merge / projector / resampler 的 safetensors
- if token_merge_pth is not None and enable_token_merge and hasattr(self, 'token_merge'):
- print_log(f'loading token_merge from {token_merge_pth}', 'current')
- merger_sd = load_file(token_merge_pth, device='cpu')
- self.token_merge.load_state_dict(merger_sd, strict=False)
- self.token_merge.to(self.llm.dtype)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- if resampler_pth is not None and self.use_resampler and hasattr(self, 'resampler'):
- print_log(f'Loading resampler from {resampler_pth}', 'current')
- resampler_sd = load_file(resampler_pth, device="cpu")
- self.resampler.load_state_dict(resampler_sd, strict=False)
- self.resampler.to(self.llm.dtype)
-
- # 额外加载 float 权重(可选)
- if pretrained_pth is not None:
- sd = guess_load_checkpoint(pretrained_pth)
- model_sd = self.state_dict()
- filtered = {k: v for k, v in sd.items() if k in model_sd and model_sd[k].shape == v.shape}
- missing, unexpected = self.load_state_dict(filtered, strict=False)
- print_log(f"Loaded float ckpt from {pretrained_pth}", "current")
- print_log(f" missing: {missing}", "current")
- print_log(f" unexpected:{unexpected}", "current")
-
- # 记录可视层
- self.visual_select_layer = visual_select_layer
-
- # 初始化标志
- self._is_init = True
- self.is_first_iter = True
-
- # === 关键新增:在 Stage-2 按需冻结三个多模态子模块 ===
- if train_stage == '2':
- # projector
- if hasattr(self, 'projector') and self._freeze_projector_in_s2:
- self.projector.requires_grad_(False)
- self.projector.eval()
- print_log('[stage-2] Freezing projector parameters', 'current')
-
- # resampler
- if getattr(self, 'use_resampler', False) and hasattr(self, 'resampler') and self._freeze_resampler_in_s2:
- self.resampler.requires_grad_(False)
- self.resampler.eval()
- print_log('[stage-2] Freezing resampler parameters', 'current')
-
- # token_merge
- if getattr(self, 'enable_token_merge', False) and hasattr(self, 'token_merge') and self._freeze_token_merge_in_s2:
- self.token_merge.requires_grad_(False)
- self.token_merge.eval()
- print_log('[stage-2] Freezing token_merge parameters', 'current')
-
-
-
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- # we use xavier_uniform following official JAX ViT:
- torch.nn.init.xavier_uniform_(m.weight)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def verify_lora(self):
- m = self.llm
-
- # 1) Wrapped as a PEFT model
- assert isinstance(m, PeftModel), "LoRA not applied: model is not a PeftModel"
-
- # 2) Adapters are registered and active
- adapters = m.peft_config # dict: {adapter_name: LoraConfig}
- assert len(adapters) > 0, "No adapters registered in peft_config"
- active = m.active_adapter if hasattr(m, "active_adapter") else None
- assert active in adapters, f"Active adapter {active} not found in peft_config"
-
- # 3) LoRA layers are present on target modules
- lora_modules = [mod for mod in m.modules() if isinstance(mod, LoraLayer)]
- assert len(lora_modules) > 0, "No LoraLayer modules found (check target_modules)"
-
- # 4) LoRA params are the only trainable ones (typical for QLoRA)
- trainable = [(n,p) for n,p in m.named_parameters() if p.requires_grad]
- assert len(trainable) > 0, "No trainable parameters (LoRA params are not set to requires_grad=True)"
- # Optional: sanity-check that trainable params look like LoRA
- suspicious = [n for n,_ in trainable if "lora_" not in n and "modules_to_save" not in n]
- # It's okay if you intentionally left some modules_to_save; adjust as needed.
- assert len(suspicious) == 0, f"Unexpected trainable params (not LoRA): {suspicious[:5]}"
-
- # 5) Quick count + readable log
- total = sum(p.numel() for _,p in m.named_parameters())
- trainable_cnt = sum(p.numel() for _,p in trainable)
- ratio = trainable_cnt / total
- print(f"[LoRA OK] adapters={list(adapters.keys())}, active={active}, "
- f"LoraLayers={len(lora_modules)}, trainable={trainable_cnt}/{total} ({ratio:.4%})")
-
- # 6) Forward+backward smoke test to confirm gradients flow to LoRA only
- m.train()
- dummy_inp = torch.randint(0, m.get_input_embeddings().num_embeddings, (1, 8)).to(next(m.parameters()).device)
- out = m(input_ids=dummy_inp, labels=dummy_inp)
- out.loss.backward() # should not error
- # Ensure some LoRA grads exist
- lora_grads = [p.grad for _,p in m.named_parameters() if p.requires_grad and p.grad is not None]
- assert len(lora_grads) > 0, "No gradients on LoRA parameters after backward()"
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self, use_reentrant=False):
- self.activation_checkpointing_enable(use_reentrant=use_reentrant)
-
- def activation_checkpointing_enable(self, use_reentrant=False):
- # LLM
- try:
- self.llm.gradient_checkpointing_enable(use_reentrant=use_reentrant)
- except TypeError:
- # older HF versions
- self.llm.gradient_checkpointing_enable()
-
- # projector
- try:
- self.projector.gradient_checkpointing_enable(use_reentrant=use_reentrant)
- except TypeError:
- self.projector.gradient_checkpointing_enable()
-
- if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None:
- try:
- self.resampler.gradient_checkpointing_enable(use_reentrant=use_reentrant)
- except:
- self.resampler.gradient_checkpointing_enable()
-
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
- if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None:
- self.resampler.gradient_checkpointing_disable()
-
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
-
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- # Step 4. Resampler
- if getattr(self, 'use_resampler', False) and getattr(self, 'resampler', None) is not None:
- to_return.update({k: v for k, v in state_dict.items() if 'resampler.' in k})
-
- # step 5 token merger
- if getattr(self, 'token_merge', False):
- to_return.update({k: v for k, v in state_dict.items() if 'token_merge.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
-
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def coords_to_pos(self, coords, tile_size: int = 224):
- """
- This function is used to convert the coordinates to the positional indices
-
- Arguments:
- ----------
- coords: torch.Tensor
- The coordinates of the patches, of shape [N, L, 2]
- output: torch.Tensor
- The positional indices of the patches, of shape [N, L]
- """
- coords_ = torch.floor(coords / tile_size)
- pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
- return pos.long() # add 1 for the cls token
-
- @staticmethod
- def _coords_rc_to_pos(coords_rc: torch.Tensor, ngrids: int) -> torch.Tensor:
- if coords_rc.dtype.is_floating_point:
- coords_rc = coords_rc.round().to(torch.long)
- # row = coords_rc[:, 0].clamp_(0, ngrids-1)
- # col = coords_rc[:, 1].clamp_(0, ngrids-1)
- return (coords_rc[..., 0] * ngrids + coords_rc[..., 1]).long() # +1 for cls
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
- coords = None
-
- if 'pixel_values' in data:
-
- feat_to_proj = data['pixel_values'].to(self.llm.dtype) # torch.Size([1, img_num, 512])
- # ensure requires_grad for gradient checkpointing
- feat_to_proj.requires_grad_(True)
-
- if 'coords' in data:
- coords = data['coords'].to(self.llm.dtype)
- # Accept: list[tensor], [L,2] tensor, or [B,L,2] tensor
- coords_t = coords[0] if isinstance(coords, list) else coords
- Bx = feat_to_proj.size(0) # actual batch size of inputs
- if not torch.is_tensor(coords_t):
- raise ValueError("coords must be a Tensor or list[Tensor].")
-
- if coords_t.dim() == 2:
- # [L, 2]
- coords_rc = coords_t
- elif coords_t.dim() == 3:
- # [B, L, 2] -> ensure B matches and either B==1 or all examples share coords
- if coords_t.size(0) != Bx:
- raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}")
- if Bx == 1:
- coords_rc = coords_t[0]
- else:
- # require same coords across the batch (cheap equality check)
- if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)):
- raise NotImplementedError(
- "Per-example coords (varying across batch) are not supported by the current "
- "patch-merging/layout path. Use batch size 1 or share coords across the batch."
- )
- coords_rc = coords_t[0]
- else:
- raise ValueError("coords must have shape [L,2] or [B,L,2].")
-
- if coords_rc.size(-1) != 2:
- raise ValueError("coords last dimension must be 2.")
- else:
- raise RuntimeError
-
- # only works for batch size one
- if self.enable_token_merge:
- feat_to_proj, coords_rc_merged, _ = self.token_merge(
- x=feat_to_proj,
- coords_rc=self._coords_to_rowcol(coords_rc),
- padmask=torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)],
- device=feat_to_proj.device, dtype=torch.bool)
- )
- # print(f"After token_merge, feat_to_proj: {feat_to_proj.shape}, coords_rc_merged: {coords_rc_merged.shape}")
- else:
- coords_rc_merged = self._coords_to_rowcol(coords_rc)
- padmask_merged = torch.zeros([feat_to_proj.size(0), feat_to_proj.size(1)],
- device=feat_to_proj.device, dtype=torch.bool)
-
- pixel_values = self.projector(feat_to_proj.to(self.llm.dtype)) # output shape [1, patch_num, hidden_size]
- # print(f"After projector, pixel_values: {pixel_values.shape}")
- if self.use_resampler and getattr(self, 'resampler', None) is not None:
- pixel_values = self.resampler(pixel_values, coords_rc_merged,
- attn_mask= None) # [1, num_latents, hidden_size]
-
- data['pixel_values'] = pixel_values
-
- # remove coords
- data.pop('coords', None)
-
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- @staticmethod
- def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
- with torch.no_grad():
- x = coords_xy[:, 0]
- y = coords_xy[:, 1]
- x_for_unique = x
- y_for_unique = y
- if x_for_unique.dtype.is_floating_point:
- x_for_unique = x_for_unique.round().to(torch.int)
- y_for_unique = y_for_unique.round().to(torch.int)
- x_sorted = torch.unique(x_for_unique, sorted=True)
- y_sorted = torch.unique(y_for_unique, sorted = True)
-
- col = torch.searchsorted(x_sorted, x)
- row = torch.searchsorted(y_sorted, y)
- return torch.stack([row, col], dim=-1)
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- """
- 计算损失的修改版实现。
- 该版本通过计算批次中每个样本的平均损失来解决长短文本的梯度失衡问题,
- 使得每个样本对总损失的贡献相等,无论其token长度如何。
- """
- # 如果 HF 模型可以自己处理,则直接返回
- if "labels" not in data:
- outputs = self.llm(**data)
- return {"loss": outputs.loss}
-
- # 将 labels 从 data 中分离出来,避免其被直接传递给模型
- labels = data.pop("labels")
-
- # 模型前向传播,获取 logits
- outputs = self.llm(**data)
- logits = outputs.logits
-
- # 验证 logits 和 labels 的形状是否匹配
- if logits.shape[:-1] != labels.shape:
- raise ValueError(
- f"Logits and labels shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}"
- )
-
- # 将 Logits 和 Labels 的 batch 维度移动到第一维,方便迭代
- # logits: [B, L, V] -> [L, B, V]
- # labels: [B, L] -> [L, B]
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- # 使用 cross_entropy 计算每个 token 的损失,但不对其进行任何聚合 (reduction='none')
- # 这将返回一个与 shift_labels 形状相同的损失张量
- loss = F.cross_entropy(
- shift_logits.view(-1, shift_logits.size(-1)),
- shift_labels.view(-1),
- ignore_index=-100,
- reduction='none'
- )
-
- # 将损失张量 reshape 回 [B, L-1]
- loss = loss.view(shift_logits.size(0), -1)
-
- # 对每个样本(每个序列)分别计算平均损失
- # 统计每个样本中有效(非-100)的 token 数量
- num_tokens_per_sample = (shift_labels != -100).sum(dim=1)
-
- # 计算每个样本的总损失
- loss_per_sample = loss.sum(dim=1)
-
- # 避免除以零
- valid_samples_mask = num_tokens_per_sample > 0
-
- # 初始化每个样本的平均损失
- mean_loss_per_sample = torch.zeros_like(loss_per_sample)
-
- # 只对有效的样本计算平均损失
- if valid_samples_mask.any():
- mean_loss_per_sample[valid_samples_mask] = loss_per_sample[valid_samples_mask] / num_tokens_per_sample[valid_samples_mask]
-
- # 最终的损失是所有样本平均损失的平均值
- final_loss = mean_loss_per_sample.mean()
-
- return {"loss": final_loss}
-
-
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
- if self.use_resampler and hasattr(self, 'resampler'):
-
- resampler_path = osp.join(save_dir, "resampler")
- print_log(f'Saving Resampler to {resampler_path}', 'current')
- os.makedirs(resampler_path, exist_ok=True)
- resampler_output_path = os.path.join(resampler_path, 'resampler.safetensors')
- save_file(self.resampler.state_dict(), resampler_output_path)
-
- if self.enable_token_merge and hasattr(self, 'token_merge'):
- merger_path = osp.join(save_dir, 'token_merger')
- print_log(f'Saving token merger to{merger_path}', 'current')
- os.makedirs(merger_path, exist_ok= True)
- merger_path = os.path.join(merger_path, 'merger.safetensors')
- save_file(self.token_merge.state_dict(), merger_path)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- if self.use_resampler:
- warnings.warn("Conversion to HuggingFace LLaVA format with a custom resampler is not supported. "
- "The resampler weights will not be saved.")
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=False, assign=True) # strict=False to ignore missing resampler
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- if self.use_resampler:
- warnings.warn("Conversion to official LLaVA format with a custom resampler is not supported. "
- "The resampler weights will not be saved.")
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=False, assign=True) # strict=False to ignore missing resampler
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/llava_only_projector.py b/code/xtuner/model/llava_only_projector.py
deleted file mode 100644
index 61a48c3fe0446c7796170e968768156e2060afc2..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_only_projector.py
+++ /dev/null
@@ -1,938 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-import os
-from safetensors.torch import load_file, save_file
-
-import torch
-import torch.distributed as dist # === MOD ===
-import torch.nn as nn
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-import torch.nn.functional as F
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- """
- 自适应平均池化层(沿序列维 L),带输入/输出 LayerNorm,并在大 L 时切换为线性插值,
- 避免 CUDA AdaptiveAvgPool 的 sharedMem 限制导致的报错。
-
- 期望输入:x ∈ [B, H, L]
- - 先在 [B, L, H] 上做输入层归一化(LayerNorm(H))。
- - 对序列维 L 做池化/插值到 output_size。
- - 再在 [B, L_out, H] 上做输出层归一化。
-
- 参数:
- output_size (int): 池化后的 token 数 L_out。
- hidden_size (int): 通道维 H 的大小(用于 LayerNorm 维度)。
- eps (float): LayerNorm eps。
- affine (bool): LayerNorm 是否带缩放平移参数。
- impl (str): 'auto' | 'pool' | 'interp'。auto 根据长度阈值自动切换。
- switch_threshold (int): 当 L >= 该阈值且 impl='auto' 时使用插值。
- pool_in_fp32 (bool): 池化/插值内部提升到 FP32 计算以增强数稳。
- """
- def __init__(self, output_size: int, hidden_size: int, eps: float = 1e-5, affine: bool = True,
- impl: str = 'auto', switch_threshold: int = 8192, pool_in_fp32: bool = True):
- super().__init__()
- if output_size <= 0:
- raise ValueError("output_size must be positive")
- if hidden_size <= 0:
- raise ValueError("hidden_size must be positive")
- if impl not in ('auto', 'pool', 'interp'):
- raise ValueError("impl must be one of {'auto','pool','interp'}")
- self.output_size = int(output_size)
- self.hidden_size = int(hidden_size)
- self.impl = impl
- self.switch_threshold = int(switch_threshold)
- self.pool_in_fp32 = bool(pool_in_fp32)
- self.in_norm = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
- self.out_norm = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- # 期待 x 形状为 [B, H, L]
- if x.dim() != 3:
- raise ValueError(f"AdaptiveAvgPool1dLayer expects 3D tensor [B,H,L], got {tuple(x.shape)}")
- B, H, L = x.shape
- if H != self.hidden_size:
- raise ValueError(f"Channel size mismatch: got H={H}, expected {self.hidden_size}")
-
- # 输入归一化:在 [B, L, H] 上做 LayerNorm(H)
- x = x.transpose(1, 2).contiguous() # [B, L, H]
- x = self.in_norm(x)
-
- x = x.transpose(1, 2).contiguous() # [B, H, L]
-
- # 选择实现:大 L 时使用插值以避免 CUDA sharedMem 报错
- use_interp = (self.impl == 'interp') or (self.impl == 'auto' and L >= self.switch_threshold)
- orig_dtype = x.dtype
- if self.pool_in_fp32 and x.dtype in (torch.float16, torch.bfloat16):
- x = x.float()
- if use_interp:
- # 线性插值在 [B, H, L] 上稳定可导
- x = F.interpolate(x, size=self.output_size, mode='linear', align_corners=False)
- else:
- x = F.adaptive_avg_pool1d(x.contiguous(), self.output_size)
- x = x.to(orig_dtype)
-
- # 输出归一化:在 [B, L_out, H] 上做 LayerNorm(H)
- x = x.transpose(1, 2).contiguous() # [B, L_out, H]
- x = self.out_norm(x)
- x = x.transpose(1, 2).contiguous() # [B, H, L_out]
- return x
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
-
- projector_pth=None,
-
- use_projector_pool = False,
- projector_pool_out_tokens = 1024,
- projector_pool_pth = None,
- projector_pool_ln_eps = 1e-6,
- projector_pool_ln_affine = True,
- ):
- super().__init__()
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- self.use_projector_pool = use_projector_pool
- if self.use_projector_pool:
- hs = int(self.llm.config.hidden_size)
- self.projector_pool = AdaptiveAvgPool1dLayer(
- output_size=int(projector_pool_out_tokens),
- hidden_size=hs,
- eps=float(projector_pool_ln_eps),
- affine=bool(projector_pool_ln_affine),
- impl= 'auto',
- switch_threshold= 10240,
- pool_in_fp32= True,
- )
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
- self.projector.enable_input_require_grads()
- # self.visual_encoder.enable_input_require_grads() # if used
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- # proj_sd = load_file(projector_pth, device="cuda")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- if pretrained_pth is not None:
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}',
- 'current')
-
- self.visual_select_layer = visual_select_layer
-
- self._is_init = True
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector_pool.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- # data_dict['pixel_values']=[[pixel_values of img1], [pixel_values of img2], ...]
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype)
-
- # ======================= FIX =======================
- # Explicitly enable gradient tracking for the input features.
- # This is the crucial step to connect the backpropagation graph
- # to the projector's weights.
- feat_to_proj.requires_grad_(True)
- # ===================================================
-
- # The diagnostic code you had was good, but this makes it proactive.
- # You can now remove the old `if using_proj_ckpt:` block
- # as this solves the root cause.
-
- pixel_values = self.projector(feat_to_proj) # Pass the grad-enabled tensor
-
- # === NEW: pool along the sequence length (tokens) to L'
- if self.use_projector_pool:
- B, L, H = pixel_values.shape
- pv = pixel_values.transpose(1, 2) # [B, H, L]
- pv = self.projector_pool(pv) # [B, H, L']
- pixel_values = pv.transpose(1, 2).contiguous() # [B, L', H]
-
- data['pixel_values'] = pixel_values
-
- data.pop('coords', None)
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
- outputs = self.llm(**data)
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
-
- # 替换 LLaVAModel 中的 compute_loss 函数
- def compute_loss(self, data, data_samples=None):
- """
- 计算损失的修改版实现。
- 该版本通过计算批次中每个样本的平均损失来解决长短文本的梯度失衡问题,
- 使得每个样本对总损失的贡献相等,无论其token长度如何。
- """
- # 如果 HF 模型可以自己处理,则直接返回
- if "labels" not in data:
- outputs = self.llm(**data)
- return {"loss": outputs.loss}
-
- # 将 labels 从 data 中分离出来,避免其被直接传递给模型
- labels = data.pop("labels")
-
- # 模型前向传播,获取 logits
- outputs = self.llm(**data)
- logits = outputs.logits
-
- # 验证 logits 和 labels 的形状是否匹配
- if logits.shape[:-1] != labels.shape:
- raise ValueError(
- f"Logits and labels shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}"
- )
-
- # 将 Logits 和 Labels 的 batch 维度移动到第一维,方便迭代
- # logits: [B, L, V] -> [L, B, V]
- # labels: [B, L] -> [L, B]
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
-
- # 使用 cross_entropy 计算每个 token 的损失,但不对其进行任何聚合 (reduction='none')
- # 这将返回一个与 shift_labels 形状相同的损失张量
- loss = F.cross_entropy(
- shift_logits.view(-1, shift_logits.size(-1)),
- shift_labels.view(-1),
- ignore_index=-100,
- reduction='none'
- )
-
- # 将损失张量 reshape 回 [B, L-1]
- loss = loss.view(shift_logits.size(0), -1)
-
- # 对每个样本(每个序列)分别计算平均损失
- # 统计每个样本中有效(非-100)的 token 数量
- num_tokens_per_sample = (shift_labels != -100).sum(dim=1)
-
- # 计算每个样本的总损失
- loss_per_sample = loss.sum(dim=1)
-
- # 避免除以零
- valid_samples_mask = num_tokens_per_sample > 0
-
- # 初始化每个样本的平均损失
- mean_loss_per_sample = torch.zeros_like(loss_per_sample)
-
- # 只对有效的样本计算平均损失
- if valid_samples_mask.any():
- mean_loss_per_sample[valid_samples_mask] = loss_per_sample[valid_samples_mask] / num_tokens_per_sample[valid_samples_mask]
-
- # 最终的损失是所有样本平均损失的平均值
- final_loss = mean_loss_per_sample.mean()
-
- return {"loss": final_loss}
-
- # def compute_loss(self, data, data_samples=None):
- # outputs = self.llm(**data)
- # loss_dict = {'loss': outputs.loss}
- # return loss_dict
-
- # def compute_loss(self, data, data_samples=None):
- # """
- # 计算 token-level 交叉熵损失(分布式/AMP 兼容)。
- # - labels 中 -100 为 ignore_index
- # - 自动屏蔽负 ID(如 -200 图像占位)与 special_ids 对应位置
- # """
-
- # # 1) 若无 labels,退回 HF 默认
- # if "labels" not in data:
- # outputs = self.llm(**data)
- # return {"loss": outputs.loss}
-
- # labels = data["labels"] # [B, T]
- # input_ids = data.get("input_ids", None) # [B, T] or None
- # attn = data.get("attention_mask", None) # 可无
-
- # # 2) 标签清洗(不改原 labels)
- # safe_labels = labels.clone()
-
- # # 2.1 屏蔽负 ID(如 -200 图像占位)
- # if input_ids is not None:
- # neg_mask = (input_ids < 0)
- # if neg_mask.any():
- # safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # # 2.2 屏蔽 tokenizer 的特殊 token(模板标记等)
- # if getattr(self, "tokenizer", None) is not None:
- # try:
- # special_ids = set(self.tokenizer.all_special_ids or [])
- # except Exception:
- # special_ids = set()
- # if special_ids:
- # special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
- # for sid in special_ids:
- # special_mask |= (input_ids == sid)
- # if special_mask.any():
- # safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # # 3) 前向,拿 logits(不把 labels 交给 HF,避免其先做 per-device mean)
- # model_inputs = {k: v for k, v in data.items() if k != "labels"}
- # outputs = self.llm(**model_inputs, use_cache=False)
- # logits = outputs.logits # [B, T, V]
-
- # # 形状断言
- # if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]:
- # raise RuntimeError(
- # f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}"
- # )
-
- # # 4) CausalLM 对齐
- # shift_logits = logits[:, :-1, :].contiguous()
- # shift_labels = safe_labels[:, 1:].contiguous()
-
- # # 5) 统计有效 token & 分布式聚合
- # n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
- # world_size = 1
- # n_tok_global = n_tok_local
- # if dist.is_available() and dist.is_initialized():
- # world_size = dist.get_world_size()
- # with torch.no_grad():
- # n_tok_global = n_tok_local.clone()
- # dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- # # 若全局无监督 token,则返回 0(防 NaN)
- # if n_tok_global.item() == 0:
- # zero = shift_logits.sum() * 0.0
- # return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}
-
- # # 6) 分子(sum over tokens,FP32 更稳)
- # loss_sum_local = F.cross_entropy(
- # shift_logits.float().view(-1, shift_logits.size(-1)),
- # shift_labels.view(-1),
- # ignore_index=-100,
- # reduction="sum",
- # )
-
- # # 7) 全局 token 平均的 loss(抵消 DDP 的梯度平均)
- # denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
- # loss = (loss_sum_local / denom) * float(world_size)
-
- # # 8) 返回
- # ntok_tensor = denom.detach()
- # return {"loss": loss, "ntok": ntok_tensor}
-
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- # self.projector.save_pretrained(projector_path,
- # **save_pretrained_kwargs)
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
-
- if self.use_projector_pool and getattr(self, 'projector_pool', None):
- projector_pool_path = osp.join(save_dir, 'projector_pool')
- print_log(f'Saving projector_pool to {projector_pool_path}', 'current')
- torch.save(self.projector_pool.state_dict(), projector_pool_path)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/llava_only_projector_detection.py b/code/xtuner/model/llava_only_projector_detection.py
deleted file mode 100644
index 2c8c5d9bc88fb34e769a3aae84096dfa4ec58fd7..0000000000000000000000000000000000000000
--- a/code/xtuner/model/llava_only_projector_detection.py
+++ /dev/null
@@ -1,1023 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import os.path as osp
-import warnings
-from collections import OrderedDict
-import os
-from safetensors.torch import load_file, save_file
-
-import torch
-import json
-from collections import deque
-
-import torch.distributed as dist # === MOD ===
-import torch.nn as nn
-from accelerate import init_empty_weights
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
- CLIPVisionModel, LlamaForCausalLM,
- LlamaTokenizerFast, LlavaConfig,
- LlavaForConditionalGeneration, LlavaProcessor)
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.registry import BUILDER
-from xtuner.utils import DEFAULT_IMAGE_TOKEN
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, guess_load_checkpoint,
- make_inputs_require_grad,
- prepare_inputs_labels_for_multimodal, traverse_dict)
-
-import torch.nn.functional as F
-
-def convert_state_dict_to_hf(state_dict, mapping):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.endswith('.inv_freq'):
- continue
- for key_to_modify, new_key in mapping.items():
- if key_to_modify in key:
- key = key.replace(key_to_modify, new_key)
- new_state_dict[key] = value
- return new_state_dict
-
-class AdaptiveAvgPool1dLayer(nn.Module):
- """
- 自适应平均池化层(沿序列维 L),带输入/输出 LayerNorm,并在大 L 时切换为线性插值,
- 避免 CUDA AdaptiveAvgPool 的 sharedMem 限制导致的报错。
-
- 期望输入:x ∈ [B, H, L]
- - 先在 [B, L, H] 上做输入层归一化(LayerNorm(H))。
- - 对序列维 L 做池化/插值到 output_size。
- - 再在 [B, L_out, H] 上做输出层归一化。
-
- 参数:
- output_size (int): 池化后的 token 数 L_out。
- hidden_size (int): 通道维 H 的大小(用于 LayerNorm 维度)。
- eps (float): LayerNorm eps。
- affine (bool): LayerNorm 是否带缩放平移参数。
- impl (str): 'auto' | 'pool' | 'interp'。auto 根据长度阈值自动切换。
- switch_threshold (int): 当 L >= 该阈值且 impl='auto' 时使用插值。
- pool_in_fp32 (bool): 池化/插值内部提升到 FP32 计算以增强数稳。
- """
- def __init__(self, output_size: int, hidden_size: int, eps: float = 1e-5, affine: bool = True,
- impl: str = 'auto', switch_threshold: int = 8192, pool_in_fp32: bool = True):
- super().__init__()
- if output_size <= 0:
- raise ValueError("output_size must be positive")
- if hidden_size <= 0:
- raise ValueError("hidden_size must be positive")
- if impl not in ('auto', 'pool', 'interp'):
- raise ValueError("impl must be one of {'auto','pool','interp'}")
- self.output_size = int(output_size)
- self.hidden_size = int(hidden_size)
- self.impl = impl
- self.switch_threshold = int(switch_threshold)
- self.pool_in_fp32 = bool(pool_in_fp32)
- self.in_norm = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
- self.out_norm = nn.LayerNorm(hidden_size, eps=eps, elementwise_affine=affine)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- # 期待 x 形状为 [B, H, L]
- if x.dim() != 3:
- raise ValueError(f"AdaptiveAvgPool1dLayer expects 3D tensor [B,H,L], got {tuple(x.shape)}")
- B, H, L = x.shape
- if H != self.hidden_size:
- raise ValueError(f"Channel size mismatch: got H={H}, expected {self.hidden_size}")
-
- # 输入归一化:在 [B, L, H] 上做 LayerNorm(H)
- x = x.transpose(1, 2).contiguous() # [B, L, H]
- x = self.in_norm(x)
-
- x = x.transpose(1, 2).contiguous() # [B, H, L]
-
- # 选择实现:大 L 时使用插值以避免 CUDA sharedMem 报错
- use_interp = (self.impl == 'interp') or (self.impl == 'auto' and L >= self.switch_threshold)
- orig_dtype = x.dtype
- if self.pool_in_fp32 and x.dtype in (torch.float16, torch.bfloat16):
- x = x.float()
- if use_interp:
- # 线性插值在 [B, H, L] 上稳定可导
- x = F.interpolate(x, size=self.output_size, mode='linear', align_corners=False)
- else:
- x = F.adaptive_avg_pool1d(x.contiguous(), self.output_size)
- x = x.to(orig_dtype)
-
- # 输出归一化:在 [B, L_out, H] 上做 LayerNorm(H)
- x = x.transpose(1, 2).contiguous() # [B, L_out, H]
- x = self.out_norm(x)
- x = x.transpose(1, 2).contiguous() # [B, H, L_out]
- return x
-
-class LLaVAModel(BaseModel):
-
- def __init__(self,
- llm,
- tokenizer,
- freeze_llm=True,
- visual_select_layer=-2,
- pretrained_pth=None,
- projector_depth=2,
- llm_lora=None,
- visual_encoder_lora=None,
- use_activation_checkpointing=True,
- max_position_embeddings=None,
- hidden_size=512,
- train_stage='2',
-
- projector_pth=None,
-
- use_projector_pool = False,
- projector_pool_out_tokens = 1024,
- projector_pool_pth = None,
- projector_pool_ln_eps = 1e-6,
- projector_pool_ln_affine = True,
- ):
- super().__init__()
-
- self.freeze_llm = freeze_llm
- self.freeze_visual_encoder = True
- if train_stage == '1':
- print('train_stage == 1')
- self.freeze_llm = True
- elif train_stage == '2':
- print('train_stage == 2')
- self.freeze_llm = False
-
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-
- self.llm = self._build_from_cfg_or_module(llm)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm)
-
- self.projector_depth = projector_depth
-
- projector_config = ProjectorConfig(
- visual_hidden_size=hidden_size,
- llm_hidden_size=self.llm.config.hidden_size,
- depth=self.projector_depth)
-
- self.projector = ProjectorModel(projector_config).to(
- self.llm.dtype)
-
- self.use_projector_pool = use_projector_pool
- if self.use_projector_pool:
- hs = int(self.llm.config.hidden_size)
- self.projector_pool = AdaptiveAvgPool1dLayer(
- output_size=int(projector_pool_out_tokens),
- hidden_size=hs,
- eps=float(projector_pool_ln_eps),
- affine=bool(projector_pool_ln_affine),
- impl= 'auto',
- switch_threshold= 10240,
- pool_in_fp32= True,
- )
-
-
- if self.freeze_llm:
- print('freeze_llm')
- self.llm.requires_grad_(False)
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
- self.projector.enable_input_require_grads()
- # self.visual_encoder.enable_input_require_grads() # if used
-
- # enable gradient (activation) checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- self.use_llm_lora = None
- self.use_visual_encoder_lora = None
-
- if self.use_llm_lora:
- self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
-
- if projector_pth is not None:
- print_log(f"Loading projector from {projector_pth}", "current")
- proj_sd = load_file(projector_pth, device="cpu")
- self.projector.load_state_dict(proj_sd, strict=False)
- self.projector.to(self.llm.dtype)
-
- if pretrained_pth is not None:
- pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
- self.load_state_dict(pretrained_state_dict, strict=False)
- print_log(f'Load pretrained weight from {pretrained_pth}',
- 'current')
-
- self.visual_select_layer = visual_select_layer
-
- # Pass the tokenizer instance to the model
- # if isinstance(tokenizer, dict):
- # # self.tokenizer = self._build_from_cfg_or_module(tokenizer)
- # else:
- # self.tokenizer = tokenizer
- self.tokenizer = self._build_from_cfg_or_module(tokenizer)
- # --- ADD SPIKE DETECTION LOGIC ---
- # A deque to store the last 50 loss values
- self.loss_history = deque(maxlen=50)
- # The multiplier for detecting a spike (e.g., 5x the average)
- self.spike_threshold_multiplier = 5.0
- # An absolute threshold to avoid flagging tiny losses (e.g., 0.1 -> 0.5)
- self.spike_threshold_abs = 1.0
- # The file to log spike data to
- self.spike_dump_dir = "loss_spike_dumps"
- os.makedirs(self.spike_dump_dir, exist_ok=True)
- self.spike_log_file = os.path.join(self.spike_dump_dir, "spike_log.jsonl")
- # --- END OF SPIKE DETECTION LOGIC ---
-
- self._is_init = True
- self.is_first_iter = True
-
- def _parse_lora_config(self, lora_config):
- if isinstance(lora_config, dict) or isinstance(
- lora_config, Config) or isinstance(lora_config, ConfigDict):
- lora_config = BUILDER.build(lora_config)
- return lora_config
-
- def _prepare_llm_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.llm)
- lora_config.target_modules = modules
- self.llm = get_peft_model(self.llm, lora_config)
-
- def _prepare_visual_encoder_for_lora(self,
- lora_config,
- use_activation_checkpointing=True):
- lora_config = self._parse_lora_config(lora_config)
- if lora_config.target_modules is None:
- modules = find_all_linear_names(self.visual_encoder)
- lora_config.target_modules = modules
- self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
- # self.visual_encoder.gradient_checkpointing_enable()
- self.projector.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
- # self.visual_encoder.gradient_checkpointing_disable()
- self.projector.gradient_checkpointing_disable()
-
- def init_weights(self):
- pass
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- to_return = OrderedDict()
- # Step 1. visual_encoder
- if self.use_visual_encoder_lora:
- to_return.update(
- get_peft_model_state_dict(
- self.visual_encoder, state_dict=state_dict))
- elif not self.freeze_visual_encoder:
- to_return.update({
- k: v
- for k, v in state_dict.items() if 'visual_encoder.' in k
- })
- # Step 2. LLM
- if self.use_llm_lora:
- to_return.update(
- get_peft_model_state_dict(self.llm, state_dict=state_dict))
- elif not self.freeze_llm:
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'llm.' in k})
- # Step 3. Projector
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector.' in k})
-
- to_return.update(
- {k: v
- for k, v in state_dict.items() if 'projector_pool.' in k})
- return to_return
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
-
- orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
- if orig_rope_scaling is None:
- orig_rope_scaling = {'factor': 1}
-
- orig_rope_scaling_factor = orig_rope_scaling[
- 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
- if orig_ctx_len:
- orig_ctx_len *= orig_rope_scaling_factor
- if max_position_embeddings > orig_ctx_len:
- scaling_factor = float(
- math.ceil(max_position_embeddings / orig_ctx_len))
- llm_cfg.rope_scaling = {
- 'type': 'linear',
- 'factor': scaling_factor
- }
-
- # hardcode for internlm2
- llm_cfg.attn_implementation = 'flash_attention_2'
- cfg.config = llm_cfg
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg, llm_cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg, llm_cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- if self.is_first_iter:
- # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState to
- # device
- # Only required in `LLaVAModel` .
- # We do not need this in `SupervisedFinetune` .
- self.to(data['input_ids'].device)
- self.is_first_iter = False
-
- # data_dict['pixel_values']=[[pixel_values of img1], [pixel_values of img2], ...]
- input_ids_copied = data.get('input_ids', None)
- if input_ids_copied is not None:
- input_ids_copied = input_ids_copied.clone().detach()
-
- if 'pixel_values' in data:
- feat_to_proj = data['pixel_values'].to(self.llm.dtype)
-
- # ======================= FIX =======================
- # Explicitly enable gradient tracking for the input features.
- # This is the crucial step to connect the backpropagation graph
- # to the projector's weights.
- feat_to_proj.requires_grad_(True)
- # ===================================================
-
- # The diagnostic code you had was good, but this makes it proactive.
- # You can now remove the old `if using_proj_ckpt:` block
- # as this solves the root cause.
-
- pixel_values = self.projector(feat_to_proj) # Pass the grad-enabled tensor
-
- # === NEW: pool along the sequence length (tokens) to L'
- if self.use_projector_pool:
- B, L, H = pixel_values.shape
- pv = pixel_values.transpose(1, 2) # [B, H, L]
- pv = self.projector_pool(pv) # [B, H, L']
- pixel_values = pv.transpose(1, 2).contiguous() # [B, L', H]
-
- data['pixel_values'] = pixel_values
-
- data.pop('coords', None)
- data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)
-
- if mode == 'loss':
-
- labels = data.get('labels', None)
- loss_dict = self.compute_loss(data, data_samples)
-
- # --- ADD SPIKE DETECTION LOGIC ---
- current_loss = loss_dict['loss'].item()
- avg_loss = sum(self.loss_history) / len(self.loss_history) if self.loss_history else current_loss
-
- # Check for a spike
- is_spike = len(self.loss_history) == self.loss_history.maxlen and \
- current_loss > (avg_loss * self.spike_threshold_multiplier) and \
- current_loss > (avg_loss + self.spike_threshold_abs)
-
- if is_spike:
- # Decode the input and label tensors to text
- # We need to handle -100 in labels and pad tokens
- question_ids = input_ids_copied[0]
- label_ids = labels[0]
-
- # Find where the answer starts (the first non -100 value in labels)
- try:
- answer_start_index = (label_ids != -100).nonzero()[0][0].item()
-
- # ============================ FIX ============================
- # Filter out negative token IDs (e.g., -200 for image placeholders)
- # before decoding to prevent OverflowError with fast tokenizers.
- question_token_ids = question_ids[:answer_start_index]
- valid_question_ids = question_token_ids[question_token_ids >= 0]
- question_part = self.tokenizer.decode(valid_question_ids, skip_special_tokens=True)
- # ===========================================================
-
- # Decode only the part of the labels that is the answer
- answer_ids = label_ids[answer_start_index:]
- answer_ids = answer_ids[answer_ids != -100] # Filter out -100
- answer_part = self.tokenizer.decode(answer_ids, skip_special_tokens=True)
- except IndexError: # Happens if there are no labels
- # ============================ FIX ============================
- # Also apply filtering here in the fallback case
- valid_question_ids = question_ids[question_ids >= 0]
- question_part = self.tokenizer.decode(valid_question_ids, skip_special_tokens=True)
- # ===========================================================
- answer_part = "[DECODING_ERROR: No labels found]"
-
- # Prepare the data to be logged
- spike_info = {
- "loss": round(current_loss, 4),
- "avg_loss": round(avg_loss, 4),
- "image": data.get("image_path", ["N/A"])[0], # Assuming image_path is passed
- "question": question_part.strip(),
- "answer": answer_part.strip(),
- }
-
- # Save to the log file
- with open(self.spike_log_file, 'a') as f:
- f.write(json.dumps(spike_info) + '\n')
-
- # Add current loss to history for the next iteration
- self.loss_history.append(current_loss)
- # --- END OF SPIKE DETECTION LOGIC ---
- del input_ids_copied
-
- return loss_dict
-
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
- outputs = self.llm(**data)
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- def compute_loss(self, data, data_samples=None):
- outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- # def compute_loss(self, data, data_samples=None):
- # """
- # 计算 token-level 交叉熵损失(分布式/AMP 兼容)。
- # - labels 中 -100 为 ignore_index
- # - 自动屏蔽负 ID(如 -200 图像占位)与 special_ids 对应位置
- # """
-
- # # 1) 若无 labels,退回 HF 默认
- # if "labels" not in data:
- # outputs = self.llm(**data)
- # return {"loss": outputs.loss}
-
- # labels = data["labels"] # [B, T]
- # input_ids = data.get("input_ids", None) # [B, T] or None
- # attn = data.get("attention_mask", None) # 可无
-
- # # 2) 标签清洗(不改原 labels)
- # safe_labels = labels.clone()
-
- # # 2.1 屏蔽负 ID(如 -200 图像占位)
- # if input_ids is not None:
- # neg_mask = (input_ids < 0)
- # if neg_mask.any():
- # safe_labels = torch.where(neg_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # # 2.2 屏蔽 tokenizer 的特殊 token(模板标记等)
- # if getattr(self, "tokenizer", None) is not None:
- # try:
- # special_ids = set(self.tokenizer.all_special_ids or [])
- # except Exception:
- # special_ids = set()
- # if special_ids:
- # special_mask = torch.zeros_like(input_ids, dtype=torch.bool)
- # for sid in special_ids:
- # special_mask |= (input_ids == sid)
- # if special_mask.any():
- # safe_labels = torch.where(special_mask, torch.full_like(safe_labels, -100), safe_labels)
-
- # # 3) 前向,拿 logits(不把 labels 交给 HF,避免其先做 per-device mean)
- # model_inputs = {k: v for k, v in data.items() if k != "labels"}
- # outputs = self.llm(**model_inputs, use_cache=False)
- # logits = outputs.logits # [B, T, V]
-
- # # 形状断言
- # if logits.dim() != 3 or logits.shape[:2] != safe_labels.shape[:2]:
- # raise RuntimeError(
- # f"logits/labels length mismatch: logits {tuple(logits.shape)} vs labels {tuple(safe_labels.shape)}"
- # )
-
- # # 4) CausalLM 对齐
- # shift_logits = logits[:, :-1, :].contiguous()
- # shift_labels = safe_labels[:, 1:].contiguous()
-
- # # 5) 统计有效 token & 分布式聚合
- # n_tok_local = (shift_labels != -100).sum().to(device=logits.device, dtype=torch.long)
-
- # world_size = 1
- # n_tok_global = n_tok_local
- # if dist.is_available() and dist.is_initialized():
- # world_size = dist.get_world_size()
- # with torch.no_grad():
- # n_tok_global = n_tok_local.clone()
- # dist.all_reduce(n_tok_global, op=dist.ReduceOp.SUM)
-
- # # 若全局无监督 token,则返回 0(防 NaN)
- # if n_tok_global.item() == 0:
- # zero = shift_logits.sum() * 0.0
- # return {"loss": zero, "ntok": n_tok_global.to(zero.dtype)}
-
- # # 6) 分子(sum over tokens,FP32 更稳)
- # loss_sum_local = F.cross_entropy(
- # shift_logits.float().view(-1, shift_logits.size(-1)),
- # shift_labels.view(-1),
- # ignore_index=-100,
- # reduction="sum",
- # )
-
- # # 7) 全局 token 平均的 loss(抵消 DDP 的梯度平均)
- # denom = n_tok_global.clamp_min(1).to(loss_sum_local.dtype)
- # loss = (loss_sum_local / denom) * float(world_size)
-
- # # 8) 返回
- # ntok_tensor = denom.detach()
- # return {"loss": loss, "ntok": ntok_tensor}
- # def compute_loss(self, data, data_samples=None):
- # """
- # 计算损失的修改版实现。
- # 该版本通过计算批次中每个样本的平均损失来解决长短文本的梯度失衡问题,
- # 使得每个样本对总损失的贡献相等,无论其token长度如何。
- # """
- # # 如果 HF 模型可以自己处理,则直接返回
- # if "labels" not in data:
- # outputs = self.llm(**data)
- # return {"loss": outputs.loss}
-
- # # 将 labels 从 data 中分离出来,避免其被直接传递给模型
- # labels = data.pop("labels")
-
- # # 模型前向传播,获取 logits
- # outputs = self.llm(**data)
- # logits = outputs.logits
-
- # # 验证 logits 和 labels 的形状是否匹配
- # if logits.shape[:-1] != labels.shape:
- # raise ValueError(
- # f"Logits and labels shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}"
- # )
-
- # # 将 Logits 和 Labels 的 batch 维度移动到第一维,方便迭代
- # # logits: [B, L, V] -> [L, B, V]
- # # labels: [B, L] -> [L, B]
- # shift_logits = logits[..., :-1, :].contiguous()
- # shift_labels = labels[..., 1:].contiguous()
-
- # # 使用 cross_entropy 计算每个 token 的损失,但不对其进行任何聚合 (reduction='none')
- # # 这将返回一个与 shift_labels 形状相同的损失张量
- # loss = F.cross_entropy(
- # shift_logits.view(-1, shift_logits.size(-1)),
- # shift_labels.view(-1),
- # ignore_index=-100,
- # reduction='none'
- # )
-
- # # 将损失张量 reshape 回 [B, L-1]
- # loss = loss.view(shift_logits.size(0), -1)
-
- # # 对每个样本(每个序列)分别计算平均损失
- # # 统计每个样本中有效(非-100)的 token 数量
- # num_tokens_per_sample = (shift_labels != -100).sum(dim=1)
-
- # # 计算每个样本的总损失
- # loss_per_sample = loss.sum(dim=1)
-
- # # 避免除以零
- # valid_samples_mask = num_tokens_per_sample > 0
-
- # # 初始化每个样本的平均损失
- # mean_loss_per_sample = torch.zeros_like(loss_per_sample)
-
- # # 只对有效的样本计算平均损失
- # if valid_samples_mask.any():
- # mean_loss_per_sample[valid_samples_mask] = loss_per_sample[valid_samples_mask] / num_tokens_per_sample[valid_samples_mask]
-
- # # 最终的损失是所有样本平均损失的平均值
- # final_loss = mean_loss_per_sample.mean()
- # # print(final_loss)
- # return {"loss": final_loss}
-
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- save_format='xtuner',
- **kwargs):
- if save_format == 'xtuner':
- self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- elif save_format == 'huggingface':
- self.to_huggingface_llava(cfg, save_dir, fp32,
- save_pretrained_kwargs)
- elif save_format == 'official':
- self.to_official_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
- else:
- raise NotImplementedError
-
- def to_xtuner_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
- # LLM
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_llm_lora:
- llm_path = osp.join(save_dir, 'llm_adapter')
- print_log(f'Saving LLM adapter to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- elif not self.freeze_llm:
- llm_path = save_dir
- print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
- print_log(f'Saving LLM to {llm_path}', 'current')
- self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
-
- # Visual Encoder
- if self.use_visual_encoder_lora:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
- print_log(
- f'Saving visual_encoder adapter to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- elif not self.freeze_visual_encoder:
- visual_encoder_path = osp.join(save_dir, 'visual_encoder')
- print_log(
- 'Saving visual_encoder image_processor to'
- f'{visual_encoder_path}', 'current')
- image_processor = BUILDER.build(cfg.image_processor)
- image_processor.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
- print_log(f'Saving visual_encoder to {visual_encoder_path}',
- 'current')
- self.visual_encoder.save_pretrained(visual_encoder_path,
- **save_pretrained_kwargs)
-
- # Projector
- projector_path = osp.join(save_dir, 'projector')
- print_log(f'Saving projector to {projector_path}', 'current')
- # self.projector.save_pretrained(projector_path,
- # **save_pretrained_kwargs)
- os.makedirs(projector_path, exist_ok=True)
- output_path = os.path.join(projector_path, 'projector.safetensors')
- save_file(self.projector.state_dict(), output_path)
-
-
- if self.use_projector_pool and getattr(self, 'projector_pool', None):
- projector_pool_path = osp.join(save_dir, 'projector_pool')
- print_log(f'Saving projector_pool to {projector_pool_path}', 'current')
- torch.save(self.projector_pool.state_dict(), projector_pool_path)
-
- def to_huggingface_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- LLM_MAPPING = {
- 'model': 'language_model.model',
- 'lm_head': 'language_model.lm_head',
- }
- VIT_MAPPING = {
- 'vision_model': 'vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'multi_modal_projector.linear_1',
- 'model.2': 'multi_modal_projector.linear_2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
- llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- text_config = llm.config
- vision_config = visual_encoder.config
- config = LlavaConfig(
- text_config=text_config,
- vision_config=vision_config,
- attn_implementation='eager')
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaForConditionalGeneration(config)
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # processor
- cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
- tokenizer = BUILDER.build(cfg.tokenizer)
-
- tokenizer.add_tokens(
- AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
- special_tokens=True)
- tokenizer.add_special_tokens({'pad_token': ''})
-
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- processor = LlavaProcessor(
- tokenizer=tokenizer, image_processor=image_processor)
-
- # Pad to 64 for performance reasons
- pad_shape = 64
-
- pre_expansion_embeddings = \
- model.language_model.model.embed_tokens.weight.data
- mu = torch.mean(pre_expansion_embeddings, dim=0).float()
- n = pre_expansion_embeddings.size()[0]
- sigma = ((pre_expansion_embeddings - mu).T
- @ (pre_expansion_embeddings - mu)) / n
- dist = torch.distributions.multivariate_normal.MultivariateNormal(
- mu, covariance_matrix=1e-5 * sigma)
-
- # We add an image token so we need to resize the model
- ori_vocab_size = config.text_config.vocab_size
- tokenizer_vocab_size = tokenizer.encode('')[-1]
- added_token = tokenizer_vocab_size - ori_vocab_size
-
- if added_token > 0:
- model.resize_token_embeddings(ori_vocab_size + added_token,
- pad_shape)
- model.language_model.model.embed_tokens.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(
- dist.sample()
- for _ in range(model.language_model.model.embed_tokens.
- weight.data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.language_model.lm_head.weight.data[
- ori_vocab_size:] = torch.stack(
- tuple(dist.sample()
- for _ in range(model.language_model.lm_head.weight.
- data[ori_vocab_size:].shape[0])),
- dim=0,
- )
- model.config.image_token_index = tokenizer.encode(
- DEFAULT_IMAGE_TOKEN)[-1]
- model.config.pad_token_id = tokenizer.encode('')[-1]
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- processor.save_pretrained(save_dir, **save_pretrained_kwargs)
-
- def to_official_llava(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={}):
-
- VIT_MAPPING = {
- 'vision_model': 'model.vision_tower.vision_tower.vision_model',
- }
- PROJECTOR_MAPPING = {
- 'model.0': 'model.mm_projector.0',
- 'model.2': 'model.mm_projector.2',
- }
- LONGNET_MAPPING = {
- 'layers.0': 'LongNet_encoder.layers.0',
- 'layers.1': 'LongNet_encoder.layers.1',
- 'layer_norm': 'LongNet_encoder.layer_norm'
- }
-
- try:
- from llava.model import LlavaConfig, LlavaLlamaForCausalLM
- except ImportError:
- raise ImportError(
- 'Please install llava with '
- '`pip install git+https://github.com/haotian-liu/LLaVA.git '
- '--no-deps`.')
-
- assert getattr(self.llm, 'hf_quantizer', None) is None, \
- 'This conversion format does not support quantized LLM.'
-
- # get state_dict
- llm = self.llm
- if self.use_llm_lora:
- llm = self.llm.merge_and_unload()
- llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- llm.half()
-
- assert isinstance(llm, LlamaForCausalLM), \
- 'This conversion format only supports LlamaForCausalLM.'
- llm_state_dict = llm.state_dict()
-
- need_visual_encoder = (not self.freeze_visual_encoder
- or self.use_visual_encoder_lora)
- visual_encoder = self.visual_encoder
- if self.use_visual_encoder_lora:
- visual_encoder = self.visual_encoder.merge_and_unload()
- assert isinstance(visual_encoder, CLIPVisionModel),\
- 'This conversion format only supports CLIPVisionModel.'
- if need_visual_encoder:
- visual_encoder_state_dict = visual_encoder.state_dict()
- visual_encoder_state_dict = convert_state_dict_to_hf(
- visual_encoder_state_dict, VIT_MAPPING)
- else:
- visual_encoder_state_dict = {}
-
- projector_state_dict = self.projector.state_dict()
- projector_state_dict = convert_state_dict_to_hf(
- projector_state_dict, PROJECTOR_MAPPING)
-
- LongNet_encoder_state_dict = self.LongNet_encoder.state_dict()
- LongNet_encoder_state_dict = convert_state_dict_to_hf(
- LongNet_encoder_state_dict, LONGNET_MAPPING)
-
- state_dict = {
- **projector_state_dict,
- **llm_state_dict,
- **visual_encoder_state_dict,
- **LongNet_encoder_state_dict
- }
-
- # init model
- tokenizer = BUILDER.build(cfg.tokenizer)
- image_processor = BUILDER.build(cfg.image_processor)
- assert isinstance(image_processor, CLIPImageProcessor),\
- 'This conversion format only supports CLIPImageProcessor.'
-
- llava_config_dict = llm.config.__dict__.copy()
- llava_config_dict.update(
- dict(
- image_aspect_ratio='pad',
- mm_hidden_size=visual_encoder.config.hidden_size,
- mm_projector_type=f'mlp{self.projector_depth}x_gelu',
- mm_use_im_patch_token=False,
- mm_use_im_start_end=False,
- mm_vision_select_feature='patch',
- mm_vision_select_layer=self.visual_select_layer,
- mm_vision_tower=visual_encoder.config.name_or_path,
- unfreeze_mm_vision_tower=need_visual_encoder,
- model_type='llava',
- use_cache=True,
- use_mm_proj=True))
-
- llava_config = LlavaConfig(**llava_config_dict)
-
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = LlavaLlamaForCausalLM(llava_config)
-
- model.load_state_dict(state_dict, strict=True, assign=True)
-
- # save
- print_log(f'Saving to {save_dir}', 'current')
-
- model.save_pretrained(save_dir, **save_pretrained_kwargs)
- image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
- tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)
\ No newline at end of file
diff --git a/code/xtuner/model/modules/__init__.py b/code/xtuner/model/modules/__init__.py
deleted file mode 100644
index 1207a9249708ff22b19db94a028b8d06f86f53a8..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .dispatch import dispatch_modules
-from .projector import ProjectorConfig, ProjectorModel
-
-__all__ = ['dispatch_modules', 'ProjectorConfig', 'ProjectorModel']
diff --git a/code/xtuner/model/modules/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/modules/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 4832097d7e5d39394f82b78441ea5e7311b19fe9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/modules/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/modules/dispatch/__init__.py b/code/xtuner/model/modules/dispatch/__init__.py
deleted file mode 100644
index e81ec7a3aa69fe25ee4a95759cdcb377e4e1ddd7..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/__init__.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import types
-
-import torch
-import transformers
-from mmengine.config.lazy import LazyObject
-from mmengine.utils import digit_version
-from transformers.utils.import_utils import is_flash_attn_2_available
-
-TRANSFORMERS_VERSION = digit_version(transformers.__version__)
-IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.38')
-# Transformers requires torch version >= 2.1.1 when using Torch SDPA.
-# Refer to https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/modeling_utils.py#L1611 # noqa: E501
-SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.1.1')
-SUPPORT_FLASH2 = is_flash_attn_2_available()
-SUPPORT_FLASH = SUPPORT_FLASH1 or SUPPORT_FLASH2
-
-USE_TRITON_KERNEL = bool(os.getenv('USE_TRITON_KERNEL', default=0))
-SUPPORT_TRITON = False
-try:
- import triton # pre-check # noqa: F401
- import triton.language as tl # pre-check # noqa: F401
- SUPPORT_TRITON = True
-except ImportError:
- if USE_TRITON_KERNEL:
- raise RuntimeError(
- 'USE_TRITON_KERNEL is set to 1, but triton has not been installed.'
- ' Run `pip install triton==2.1.0` to install triton.')
-
-NO_ATTN_WEIGHTS_MSG = (
- 'Due to the implementation of the PyTorch version of flash attention, '
- 'even when the `output_attentions` flag is set to True, it is not '
- 'possible to return the `attn_weights`.')
-
-LOWEST_TRANSFORMERS_VERSION = dict(
- InternLM2ForCausalLM=digit_version('4.36'),
- InternLMForCausalLM=digit_version('4.36'),
- LlamaForCausalLM=digit_version('4.36'),
- Phi3ForCausalLM=digit_version('4.39'),
- MistralForCausalLM=digit_version('4.36'),
- # Training mixtral with lower version may lead to nccl timeout
- # Refer to https://github.com/microsoft/DeepSpeed/issues/5066
- MixtralForCausalLM=digit_version('4.40'),
- CohereForCausalLM=digit_version('4.40'),
- Qwen2ForCausalLM=digit_version('4.39'),
- Qwen2MoeForCausalLM=digit_version('4.40'),
- DeepseekV2ForCausalLM=digit_version('4.40'),
-)
-
-ATTN_DISPATCH_MAPPING = dict(
- InternLM2FlashAttention2=LazyObject(
- 'xtuner.model.modules.dispatch.internlm2', 'internlm2_attn_forward'),
- InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm',
- 'internlm_attn_forward'),
- LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
- 'llama_attn_forward'),
- Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3',
- 'phi3_attn_forward'),
- MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
- 'mistral_attn_forward'),
- MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
- 'mistral_attn_forward'),
- CohereFlashAttention2=LazyObject('xtuner.model.modules.dispatch.cohere',
- 'cohere_attn_forward'),
- Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
- 'qwen2_attn_forward'),
- Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
- 'qwen2_attn_forward'),
- DeepseekV2FlashAttention2=LazyObject(
- 'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_attn_forward'),
-)
-
-ATTN_LEGACY_DISPATCH_MAPPING = dict(
- LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
- 'llama_attn_forward_legacy'), )
-
-VARLEN_ATTN_DISPATCH_MAPPING = dict(
- InternLM2FlashAttention2=LazyObject(
- 'xtuner.model.modules.dispatch.internlm2',
- 'internlm2_varlen_attn_forward'),
- InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm',
- 'internlm_varlen_attn_forward'),
- LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
- 'llama_varlen_attn_forward'),
- Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3',
- 'phi3_varlen_attn_forward'),
- MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
- 'mistral_varlen_attn_forward'),
- MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral',
- 'mistral_varlen_attn_forward'),
- CohereFlashAttention2=None,
- Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
- 'qwen2_varlen_attn_forward'),
- Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2',
- 'qwen2_varlen_attn_forward'),
- DeepseekV2FlashAttention2=LazyObject(
- 'xtuner.model.modules.dispatch.deepseek_v2',
- 'deepseek_varlen_attn_forward'),
-)
-
-VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = dict(
- LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama',
- 'llama_varlen_attn_forward_legacy'), )
-
-RMS_DISPATCH_MAPPING = dict(
- InternLM2RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- InternLMRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- LlamaRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- Phi3RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- MistralRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- MixtralRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- CohereLayerNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'layer_norm_forward'),
- Qwen2RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
- Qwen2MoeRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels',
- 'rms_norm_forward'),
-)
-
-ROTE_DISPATCH_MAPPING = dict(
- InternLMRotaryEmbedding=LazyObject(
- 'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
- MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
- 'MistralRotaryEmbedding'),
- MixtralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
- 'MistralRotaryEmbedding'),
-)
-
-
-def log_once(func):
- logged = False
-
- def wrapper(*args, **kwargs):
- nonlocal logged
- if not logged:
- logged = True
- func(*args, **kwargs)
- return
-
- return wrapper
-
-
-def dispatch_attn_forward(model):
-
- if not SUPPORT_FLASH2:
- return
-
- from mmengine import print_log
- print_log = log_once(print_log)
-
- attn_forward = None
- for module in model.modules():
- name = type(module).__name__
- if (IS_LOW_VERSION_TRANSFORMERS
- and name in ATTN_LEGACY_DISPATCH_MAPPING):
- if attn_forward is None:
- attn_forward = ATTN_LEGACY_DISPATCH_MAPPING[name]
- attn_forward = attn_forward.build()
- print_log(f'Dispatch {name} legacy forward. {NO_ATTN_WEIGHTS_MSG}',
- 'current')
- module.forward = types.MethodType(attn_forward, module)
- elif name in ATTN_DISPATCH_MAPPING:
- if attn_forward is None:
- attn_forward = ATTN_DISPATCH_MAPPING[name]
- attn_forward = attn_forward.build()
- print_log(f'Dispatch {name} forward. {NO_ATTN_WEIGHTS_MSG}',
- 'current')
- module.forward = types.MethodType(attn_forward, module)
-
-
-def dispatch_varlen_attn_forward(model):
-
- if not SUPPORT_FLASH2:
- return
-
- from mmengine import print_log
- print_log = log_once(print_log)
-
- varlen_attn_forward = None
- for module in model.modules():
- name = type(module).__name__
- if (IS_LOW_VERSION_TRANSFORMERS
- and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING):
- if varlen_attn_forward is None:
- varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name]
- varlen_attn_forward = varlen_attn_forward.build()
- print_log(
- f'Dispatch legacy {name} varlen forward. '
- f'{NO_ATTN_WEIGHTS_MSG}', 'current')
- module.forward = types.MethodType(varlen_attn_forward, module)
- elif name in VARLEN_ATTN_DISPATCH_MAPPING:
- if varlen_attn_forward is None:
- varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name]
- varlen_attn_forward = varlen_attn_forward.build()
- print_log(f'Dispatch {name} varlen forward. {NO_ATTN_WEIGHTS_MSG}',
- 'current')
- module.forward = types.MethodType(varlen_attn_forward, module)
-
-
-def dispatch_rmsnorm_forward(model):
-
- if (not SUPPORT_TRITON) or (not USE_TRITON_KERNEL):
- return
-
- from mmengine import print_log
- print_log = log_once(print_log)
-
- rms_forward = None
- for module in model.modules():
- name = type(module).__name__
- if name in RMS_DISPATCH_MAPPING:
- if rms_forward is None:
- rms_forward = RMS_DISPATCH_MAPPING[name]
- rms_forward = rms_forward.build()
- print_log(f'Dispatch {name} forward.', 'current')
- module.forward = types.MethodType(rms_forward, module)
-
-
-def replace_rote(model):
-
- from mmengine import print_log
- print_log = log_once(print_log)
-
- def traverse(module):
- for name, child in module.named_children():
- cls_name = type(child).__name__
- if cls_name in ROTE_DISPATCH_MAPPING:
- assert hasattr(model.config, 'rope_theta'), \
- '`rope_theta` should be in the model config.'
- rope_theta = model.config.rope_theta
-
- rote = ROTE_DISPATCH_MAPPING[cls_name]
- rote = rote.build()
- print_log(f'replace {cls_name}', 'current')
- dim_model = child.inv_freq.shape[0] * 2
- child_new = rote(dim_model, child.max_seq_len_cached,
- rope_theta).to(
- device=child.inv_freq.device,
- dtype=child.inv_freq.dtype)
- setattr(module, name, child_new)
- else:
- traverse(child)
-
- traverse(model)
-
-
-def dispatch_modules(model, use_varlen_attn=False):
-
- def check(model_name):
- if 'ForCausalLM' not in model_name and model_name.endswith('Model'):
- # a walkaround for reward model
- model_name = model_name[:-5] + 'ForCausalLM'
- msg = '{} requires transformers version at least {}, but got {}'
- if model_name in LOWEST_TRANSFORMERS_VERSION:
- assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
- model_name], msg.format(
- model_name, LOWEST_TRANSFORMERS_VERSION[model_name],
- TRANSFORMERS_VERSION)
-
- check(type(model).__name__)
- if use_varlen_attn:
- dispatch_varlen_attn_forward(model)
- else:
- dispatch_attn_forward(model)
- dispatch_rmsnorm_forward(model)
- replace_rote(model)
-
-
-__all__ = ['dispatch_modules']
diff --git a/code/xtuner/model/modules/dispatch/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/modules/dispatch/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index c31df187ce81c251bb847993cf8120c023d2b1f9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/modules/dispatch/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/modules/dispatch/attention.py b/code/xtuner/model/modules/dispatch/attention.py
deleted file mode 100644
index e89bb511cc946e521438c442caca97c1f594403b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/attention.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from xtuner.parallel.sequence import sequence_parallel_wrapper
-from .utils import upad_qkv
-
-SUPPORT_FLASH2 = False
-
-try:
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import pad_input
- SUPPORT_FLASH2 = True
-except ImportError:
- pass
-
-
-@sequence_parallel_wrapper
-def flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- dropout_p=0.0,
- softmax_scale=None,
- causal=True,
- window_size=(-1, -1), # -1 means infinite context window
-):
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout_p=dropout_p,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=window_size)
- return attn_output
-
-
-@sequence_parallel_wrapper
-def flash_attn_w_mask(
- query_states, # bs, q_len, nhead, h_dim
- key_states,
- value_states,
- attention_mask,
- softmax_scale=None,
- causal=True,
- dropout_p=0.0,
- window_size=(-1, -1), # -1 means infinite context window
-):
- batch_size, q_len = query_states.shape[:2]
- query_states, key_states, value_states, indices_q, \
- cu_seq_lens, max_seq_lens = upad_qkv(
- query_states, key_states, value_states, attention_mask, q_len)
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- softmax_scale=softmax_scale,
- dropout_p=dropout_p,
- causal=causal,
- window_size=window_size)
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
- return attn_output
-
-
-@sequence_parallel_wrapper
-def varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- softmax_scale=None,
- dropout_p=0.,
- causal=True,
- window_size=(-1, -1), # -1 means infinite context window
-):
- q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
- 0, 1), value_states.flatten(0, 1)
- attn_output = flash_attn_varlen_func(
- q_unpad,
- k_unpad,
- v_unpad,
- cumulative_len,
- cumulative_len,
- max_seqlen,
- max_seqlen,
- softmax_scale=softmax_scale,
- dropout_p=dropout_p,
- return_attn_probs=False,
- causal=causal,
- window_size=window_size)
- attn_output = attn_output.unsqueeze(0)
- return attn_output
diff --git a/code/xtuner/model/modules/dispatch/baichuan.py b/code/xtuner/model/modules/dispatch/baichuan.py
deleted file mode 100644
index 738c49869882a16bcea06f9efb18e41d8a76d1e8..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/baichuan.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def baichuan2_norm_head_forward(self, hidden_states):
- norm_weight = nn.functional.normalize(self.weight)
- return nn.functional.linear(hidden_states, norm_weight)
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
- cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
- sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
- k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
- return q_embed.to(q.dtype), k_embed.to(k.dtype)
-
-
-def baichuan_7b_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- proj = self.W_pack(hidden_states)
- proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(
- 0, -2).squeeze(-2)
- query_states = proj[0].view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = proj[1].view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- value_states = proj[2].view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
- # [bsz, nh, t, hd]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-def baichuan_13b_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- proj = self.W_pack(hidden_states)
- proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(
- 0, -2).squeeze(-2)
- query_states = proj[0].view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = proj[1].view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- value_states = proj[2].view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
- if attention_mask is not None:
- if q_len == 1: # inference with cache
- if len(attention_mask.size()) == 4:
- attention_mask = attention_mask[:, :, -1:, :]
- else:
- attention_mask = attention_mask[:, -1:, :]
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- return attn_output, None, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/cohere.py b/code/xtuner/model/modules/dispatch/cohere.py
deleted file mode 100644
index 8acf067474409e4f5a7a108b2b86c762c2fad37c..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/cohere.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-import transformers
-from mmengine.utils import digit_version
-from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
-
-from xtuner.parallel.sequence import get_sequence_parallel_world_size
-from xtuner.parallel.sequence.attention import (
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn)
-
-try:
- from transformers.cache_utils import Cache
-except ImportError:
-
- class Cache:
- pass
-
-
-TRANSFORMERS_VERSION = digit_version(transformers.__version__)
-IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43')
-
-if not IS_LOW_VERSION_TRANSFORMERS:
- from transformers.modeling_flash_attention_utils import \
- _flash_attention_forward
-
-
-def cohere_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
-):
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
- if self.use_qk_norm:
- query_states = self.q_norm(query_states)
- key_states = self.k_norm(key_states)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin)
-
- past_key_value = getattr(self, 'past_key_value', past_key_value)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; position_ids needed for
- # the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # TODO: These transpose are quite inefficient but Flash Attention requires
- # the layout [batch_size, sequence_length, num_heads, head_dim].
- # We would need to refactor the KV cache to be able to avoid many of
- # these transpose/reshape/view.
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- # Ignore copy
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and self.training)
- if enable_sequence_parallel:
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states)
- # self.num_heads is used in self._upad_input method
- # num_heads has been changed because of sequence parallel
- ori_num_head = self.num_heads
- self.num_heads = query_states.shape[-2]
-
- if IS_LOW_VERSION_TRANSFORMERS:
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=dropout_rate)
- else:
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=dropout_rate,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- is_causal=self.is_causal,
- )
-
- if enable_sequence_parallel:
- attn_output = post_process_for_sequence_parallel_attn(attn_output)
- self.num_heads = ori_num_head
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/deepseek_v2.py b/code/xtuner/model/modules/dispatch/deepseek_v2.py
deleted file mode 100644
index bfa3ebb6db8c4a7c1bb4e04a004d24e3f774755a..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/deepseek_v2.py
+++ /dev/null
@@ -1,308 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from mmengine import MessageHub
-from transformers.cache_utils import Cache
-
-from xtuner.model.transformers_models.deepseek_v2.modeling_deepseek import \
- apply_rotary_pos_emb
-from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn)
-from .attention import flash_attn_wo_mask, varlen_flash_attn
-
-
-def deepseek_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-):
- # DeepseekV2FlashAttention2 attention does not support output_attentions
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in '
- 'v4.37. Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.q_lora_rank is None:
- q = self.q_proj(hidden_states)
- else:
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
- q_nope, q_pe = torch.split(
- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
- compressed_kv, k_pe = torch.split(
- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
- kv = (
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
- bsz, q_len, self.num_heads,
- self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
-
- k_nope, value_states = torch.split(
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- kv_seq_len = value_states.shape[-2]
-
- kv_seq_len = value_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- assert position_ids is not None, '`position_ids` should not be None.'
- if self.training:
- cos, sin = self.rotary_emb(
- value_states, seq_len=position_ids.max() + 1)
- else:
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
- query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
- query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
-
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
- key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
- key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
-
- if self.q_head_dim != self.v_head_dim:
- value_states = F.pad(value_states,
- [0, self.q_head_dim - self.v_head_dim])
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (DeepseekV2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- # Handle the case where the model is quantized
- if hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- elif torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- else:
- target_dtype = self.q_a_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and self.training)
- if enable_sequence_parallel:
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states)
- # self.num_heads is used in self._upad_input method
- # num_heads has been changed because of sequence parallel
- ori_num_head = self.num_heads
- self.num_heads = query_states.shape[-2]
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=dropout_rate,
- softmax_scale=self.softmax_scale,
- )
-
- if enable_sequence_parallel:
- attn_output = post_process_for_sequence_parallel_attn(attn_output)
- self.num_heads = ori_num_head
-
- if self.q_head_dim != self.v_head_dim:
- attn_output = attn_output[:, :, :, :self.v_head_dim]
-
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
- self.v_head_dim).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def deepseek_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-):
- is_training = self.training
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
-
- assert is_training == (cumulative_len is not None) == (
- past_key_value is None)
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.q_lora_rank is None:
- q = self.q_proj(hidden_states)
- else:
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
- q_nope, q_pe = torch.split(
- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
- compressed_kv, k_pe = torch.split(
- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
- kv = (
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
- bsz, q_len, self.num_heads,
- self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
-
- k_nope, value_states = torch.split(
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- kv_seq_len = value_states.shape[-2]
-
- kv_seq_len = value_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- assert position_ids is not None, '`position_ids` should not be None.'
- if self.training:
- cos, sin = self.rotary_emb(
- value_states, seq_len=position_ids.max() + 1)
- else:
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
- query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
- query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
-
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
- key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
- key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
-
- if self.q_head_dim != self.v_head_dim:
- value_states = F.pad(value_states,
- [0, self.q_head_dim - self.v_head_dim])
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (DeepseekV2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- # Handle the case where the model is quantized
- if hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- elif torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- else:
- target_dtype = self.q_a_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # ----------------- varlen flash attention forward ----------------------#
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- causal = self.is_causal and q_len != 1
-
- if is_training:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- softmax_scale=self.softmax_scale,
- causal=causal,
- dropout_p=dropout_rate,
- training=True)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- softmax_scale=self.softmax_scale,
- causal=causal,
- dropout_p=dropout_rate,
- training=False)
-
- # ---------------- varlen flash attention forward end ------------------ #
-
- if self.q_head_dim != self.v_head_dim:
- attn_output = attn_output[:, :, :, :self.v_head_dim]
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.num_heads * self.v_head_dim)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/internlm.py b/code/xtuner/model/modules/dispatch/internlm.py
deleted file mode 100644
index 37ca9ad310e056bc357235fa935004da79a3edd7..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/internlm.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Tuple
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from mmengine import MessageHub
-
-from .triton_kernels import apply_rotary_emb
-
-SUPPORT_FLASH2 = False
-
-try:
- from flash_attn import flash_attn_func, flash_attn_varlen_func
-
- SUPPORT_FLASH2 = True
-except ImportError:
- pass
-
-
-class InternLMRotaryEmbedding(torch.nn.Module):
-
- def __init__(self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None):
- super().__init__()
- self.inv_freq = 1.0 / (
- base**(torch.arange(0, dim, 2).float().to(device) / dim))
-
- # Build here to make `torch.jit.trace` work.
- self.max_seq_len_cached = max_position_embeddings
- t = torch.arange(
- self.max_seq_len_cached,
- device=self.inv_freq.device,
- dtype=self.inv_freq.dtype)
- freqs = torch.einsum('i,j->ij', t, self.inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
- self.cos_cached = emb.cos()
- self.sin_cached = emb.sin()
-
- def forward(self, x, seq_len):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if (seq_len > self.max_seq_len_cached
- or self.cos_cached.device != x.device
- or self.cos_cached.dtype != x.dtype):
- self.max_seq_len_cached = seq_len
- assert self.inv_freq.dtype == torch.float32
- t = torch.arange(
- self.max_seq_len_cached,
- device=x.device,
- dtype=self.inv_freq.dtype)
- freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device))
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
- self.cos_cached = emb.cos().to(x.dtype)
- self.sin_cached = emb.sin().to(x.dtype)
- return (
- self.cos_cached[:seq_len, ...],
- self.sin_cached[:seq_len, ...],
- )
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def internlm_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(
- 1, 2)
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(
- 1, 2)
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(
- 1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
- # [bsz, nh, t, hd]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- if SUPPORT_FLASH2:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
- attn_output = attn_output.contiguous()
- else:
- # use flash attention implemented by pytorch
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
- attn_output = attn_output.transpose(1, 2)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
-
-
-def internlm_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
- use_varlen_atten = (cumulative_len is not None)
-
- bsz, q_len, _ = hidden_states.size()
- assert bsz == 1, (f'If utilizing local attention, the batch size should be'
- f' set to 1, but got {bsz}')
-
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
- self.head_dim)
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
- self.head_dim)
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
- self.head_dim)
-
- kv_seq_len = key_states.shape[-3]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- if use_varlen_atten:
- cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states,
- cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- else:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- cos, sin = self.rotary_emb(value_states, kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- assert SUPPORT_FLASH2
- if use_varlen_atten:
- q_unpad, k_unpad, v_unpad = query_states.flatten(
- 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
- attn_output = flash_attn_varlen_func(
- q_unpad,
- k_unpad,
- v_unpad,
- cumulative_len,
- cumulative_len,
- max_seqlen,
- max_seqlen,
- 0,
- return_attn_probs=False,
- causal=True,
- )
- else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/internlm2.py b/code/xtuner/model/modules/dispatch/internlm2.py
deleted file mode 100644
index 7c601f0dc66c056c979a84efbb18b9125cfb44cf..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/internlm2.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Tuple
-
-import torch
-import torch.distributed as dist
-from einops import rearrange
-from mmengine import MessageHub
-from transformers.cache_utils import Cache, StaticCache
-
-from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn)
-from .attention import SUPPORT_FLASH2, flash_attn_wo_mask, varlen_flash_attn
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """This is the equivalent of torch.repeat_interleave(x, dim=1,
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
- to (batch, seqlen, num_attention_heads, head_dim)"""
- batch, slen, num_key_value_heads, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, :,
- None, :].expand(batch, slen,
- num_key_value_heads, n_rep,
- head_dim)
- return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
- head_dim)
-
-
-def internlm2_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
-):
- if isinstance(past_key_value, StaticCache):
- raise ValueError(
- '`static` cache implementation is not compatible with '
- '`attn_implementation==flash_attention_2` make sure to use `sdpa` '
- 'in the mean time, and open an issue at '
- 'https://github.com/huggingface/transformers')
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models;
- # cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (InternLM2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.wqkv.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and self.training)
- if enable_sequence_parallel:
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states)
- # self.num_heads is used in self._upad_input method
- # num_heads has been changed because of sequence parallel
- ori_num_head = self.num_heads
- self.num_heads = query_states.shape[-2]
-
- dropout_rate = 0.0
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=dropout_rate)
-
- if enable_sequence_parallel:
- attn_output = post_process_for_sequence_parallel_attn(attn_output)
- self.num_heads = ori_num_head
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def internlm2_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
-
- if isinstance(past_key_value, StaticCache):
- raise ValueError(
- '`static` cache implementation is not compatible with '
- '`attn_implementation==flash_attention_2` make sure to use `sdpa` '
- 'in the mean time, and open an issue at '
- 'https://github.com/huggingface/transformers')
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
- use_varlen_atten = (cumulative_len is not None)
-
- bsz, q_len, _ = hidden_states.size()
-
- assert bsz == 1, (f'If utilizing local attention, the batch size should be'
- f' set to 1, but got {bsz}')
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- 'b q (h gs d) -> b q h gs d',
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., :self.num_key_value_groups, :]
- query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- try:
- cos, sin = self.rotary_emb(value_states, position_ids)
- except RuntimeError:
- raise RuntimeError(
- 'You are using the old version of InternLM2 model. The '
- '`modeling_internlm2.py` is outdated. Please update the InternLM2 '
- 'model.')
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models;
- # cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (InternLM2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.wqkv.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # repeat kv for sequence parallel
- key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
- value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
- assert SUPPORT_FLASH2
-
- dropout_rate = 0.0
- if use_varlen_atten:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- causal=True,
- dropout_p=dropout_rate,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=True,
- dropout_p=dropout_rate,
- training=self.training)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.wo(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/internlm3.py b/code/xtuner/model/modules/dispatch/internlm3.py
deleted file mode 100644
index 8ac15316b751060cd90f9d443d628a68be68fb32..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/internlm3.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import Callable, Optional, Tuple
-
-import torch
-import torch.distributed as dist
-from mmengine import MessageHub
-from transformers.cache_utils import Cache
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from transformers.models.llama.modeling_llama import (
- apply_rotary_pos_emb,
- eager_attention_forward,
- repeat_kv,
-)
-from transformers.processing_utils import Unpack
-
-from xtuner.parallel.sequence import get_sequence_parallel_world_size
-from xtuner.parallel.sequence.attention import (
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn,
-)
-
-
-def internlm3_attn_forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
-):
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed
- # for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs
- )
-
- # different from LlamaAttention.forward
- # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- enable_sequence_parallel = (
- dist.is_initialized()
- and get_sequence_parallel_world_size() > 1
- and self.training
- )
- if enable_sequence_parallel:
- # Reashape for `pre_process_for_sequence_parallel_attn`
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- query_states, key_states, value_states = pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states
- )
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- # different places end
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- if self.config._attn_implementation == "sdpa" and kwargs.get(
- "output_attentions", False
- ):
- warnings.warn(
- "`torch.nn.functional.scaled_dot_product_attention` does not "
- "support `output_attentions=True`. Falling back to eager "
- "attention. This warning can be removed using the argument"
- ' `attn_implementation="eager"` when loading the model.'
- )
- else:
- attention_interface = ALL_ATTENTION_FUNCTIONS[
- self.config._attn_implementation
- ]
-
- message_hub = MessageHub.get_instance("varlen_attn_args")
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f"cumulative_len_rank_{rank}")
- use_varlen_atten = cumulative_len is not None
- if use_varlen_atten:
- # When gradient_checkpointing is enabled, the flash_attn_kwargs
- # parameter is not automatically passed to the model. In such
- # cases, parameters like cu_seq_lens_q and max_length_q are
- # computed based on position_ids. However, when sequence
- # parallel is enabled, position_ids is split along the
- # sequence length, leading to incorrect calculations of these
- # parameters.
- # To address this issue, it is necessary to manually provide
- # the flash_attn_kwargs parameters.
- max_seqlen = message_hub.get_info(f"max_seqlen_rank_{rank}")
- kwargs["cu_seq_lens_q"] = cumulative_len
- kwargs["cu_seq_lens_k"] = cumulative_len
- kwargs["max_length_q"] = max_seqlen
- kwargs["max_length_k"] = max_seqlen
- kwargs.pop("position_ids", None)
-
- # Hacky: `sdpa_attention_forward` does repeat_kv based on
- # module.num_key_value_groups but it is done before
- num_key_value_groups = self.num_key_value_groups
- self.num_key_value_groups = 1
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- **kwargs,
- )
- self.num_key_value_groups = num_key_value_groups
-
- # different from LlamaAttention.forward
- if enable_sequence_parallel:
- attn_output = post_process_for_sequence_parallel_attn(attn_output)
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
diff --git a/code/xtuner/model/modules/dispatch/llama.py b/code/xtuner/model/modules/dispatch/llama.py
deleted file mode 100644
index 8132096fd484f43535543ed8f6de3efe36491c7b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/llama.py
+++ /dev/null
@@ -1,524 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import warnings
-from typing import Optional, Tuple
-
-import torch
-import torch.distributed as dist
-from mmengine import MessageHub
-from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb,
- repeat_kv)
-from transformers.utils import is_flash_attn_greater_or_equal_2_10
-
-from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask,
- varlen_flash_attn)
-from .triton_kernels import apply_rotary_emb
-
-try:
- from transformers.cache_utils import Cache
-except ImportError:
-
- class Cache:
- pass
-
-
-def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
- to (batch, seqlen, num_attention_heads, head_dim)"""
- batch, slen, num_key_value_heads, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, :,
- None, :].expand(batch, slen,
- num_key_value_heads, n_rep,
- head_dim)
- return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
- head_dim)
-
-
-def llama_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
-):
- # Modified from https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/llama/modeling_llama.py#L422 # noqa:E501
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin)
-
- past_key_value = getattr(self, 'past_key_value', past_key_value)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models;
- # cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- assert SUPPORT_FLASH2
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- if is_flash_attn_greater_or_equal_2_10():
- causal = self.is_causal
- else:
- # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
- # is bumped to 2.1. For details, please see the comment in
- # LlamaFlashAttention2 __init__.
- causal = self.is_causal and q_len != 1
-
- # the shape of attention_mask used by flash_attn and
- # F.scaled_dot_product_attention are different
- assert attention_mask is None or attention_mask.ndim == 2, \
- ('When using flash_attn, attention_mask.ndim should equal to 2.'
- f'But got attention_mask.shape = {attention_mask.shape}.'
- 'We can pass the `attn_implementation="flash_attention_2"` flag '
- 'to `.from_pretrained` method when instantiating a Internlm2 '
- 'model.')
-
- if attention_mask is not None:
- attn_output = flash_attn_w_mask(
- query_states,
- key_states,
- value_states,
- attention_mask,
- causal=causal,
- dropout_p=dropout_rate,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=causal,
- dropout_p=dropout_rate,
- training=self.training)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def llama_attn_forward_legacy(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in '
- 'v4.37. Please make sure use `attention_mask` instead.`')
-
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
- assert position_ids is not None
- if self.training:
- cos, sin = self.rotary_emb(
- value_states, seq_len=position_ids.max() + 1)
- else:
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- assert SUPPORT_FLASH2
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- if is_flash_attn_greater_or_equal_2_10():
- causal = self.is_causal
- else:
- # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
- # is bumped to 2.1. For details, please see the comment in
- # LlamaFlashAttention2 __init__.
- causal = self.is_causal and q_len != 1
-
- # the shape of attention_mask used by flash_attn and
- # F.scaled_dot_product_attention are different
- assert attention_mask is None or attention_mask.ndim == 2, \
- ('When using flash_attn, attention_mask.ndim should equal to 2.'
- f'But got attention_mask.shape = {attention_mask.shape}.'
- 'We can pass the `attn_implementation="flash_attention_2"` flag '
- 'to `.from_pretrained` method when instantiating a Internlm2 '
- 'model.')
-
- if attention_mask is not None:
- attn_output = flash_attn_w_mask(
- query_states,
- key_states,
- value_states,
- attention_mask=attention_mask,
- causal=causal,
- dropout_p=dropout_rate,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=causal,
- dropout_p=dropout_rate,
- training=self.training)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
-
-
-def llama_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
- use_varlen_atten = (cumulative_len is not None)
-
- if 'padding_mask' in kwargs:
- warnings.warn('Passing `padding_mask` is deprecated and will be '
- 'removed in v4.37. Please make sure use '
- '`attention_mask` instead.`')
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- cos, sin = self.rotary_emb(value_states, position_ids)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin)
-
- past_key_value = getattr(self, 'past_key_value', past_key_value)
-
- if past_key_value is not None:
- # sin and cos are specific to RoPE models;
- # cache_position needed for the static cache
- cache_kwargs = {
- 'sin': sin,
- 'cos': cos,
- 'cache_position': cache_position
- }
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # repeat kv for sequence parallel
- key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
- value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently casted
- # in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- assert SUPPORT_FLASH2
- if use_varlen_atten:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- causal=True,
- dropout_p=dropout_rate,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=True,
- training=self.training)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-def llama_varlen_attn_forward_legacy(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
- use_varlen_atten = (cumulative_len is not None)
-
- if 'padding_mask' in kwargs:
- warnings.warn('Passing `padding_mask` is deprecated and will be '
- 'removed in v4.37. Please make sure use '
- '`attention_mask` instead.`')
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
-
- kv_seq_len = key_states.shape[-3]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- if use_varlen_atten:
- cos, sin = self.rotary_emb(value_states, max_seqlen)
- # position_ids (1, seq_len)
- # cos, sin (1, seq_len, dim) -> (seq_len, dim)
- cos = cos[position_ids].squeeze(0)
- sin = sin[position_ids].squeeze(0)
- query_states = apply_rotary_emb(query_states, cos, sin)
- key_states = apply_rotary_emb(key_states, cos, sin)
- else:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- cos, sin = self.rotary_emb(value_states, kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # repeat kv for sequence parallel
- key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
- value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently casted
- # in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- assert SUPPORT_FLASH2
- if use_varlen_atten:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- causal=True,
- dropout_p=dropout_rate,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=True,
- dropout_p=dropout_rate,
- training=self.training)
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/mistral.py b/code/xtuner/model/modules/dispatch/mistral.py
deleted file mode 100644
index dc6c7fed827f229aeb286a35d2b290126f07e965..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/mistral.py
+++ /dev/null
@@ -1,447 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import inspect
-import warnings
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import transformers
-from mmengine import MessageHub
-from mmengine.utils import digit_version
-from transformers.cache_utils import Cache
-from transformers.models.mistral.modeling_mistral import (apply_rotary_pos_emb,
- repeat_kv)
-
-from xtuner.parallel.sequence import get_sequence_parallel_world_size
-from xtuner.parallel.sequence.attention import (
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn)
-from .attention import flash_attn_wo_mask, varlen_flash_attn
-from .triton_kernels import apply_rotary_emb
-
-SUPPORT_FLASH2 = False
-
-try:
- from flash_attn import flash_attn_func
- _flash_supports_window_size = 'window_size' in list(
- inspect.signature(flash_attn_func).parameters)
- SUPPORT_FLASH2 = True
-except ImportError:
- pass
-
-TRANSFORMERS_VERSION = digit_version(transformers.__version__)
-IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43')
-
-if not IS_LOW_VERSION_TRANSFORMERS:
- from transformers.modeling_flash_attention_utils import \
- _flash_attention_forward
-
-
-class MistralRotaryEmbedding(nn.Module):
-
- def __init__(self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- self.inv_freq = 1.0 / (
- base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings,
- device=self.inv_freq.device,
- dtype=torch.get_default_dtype())
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
- freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(device))
- # Different from paper, but it uses a different permutation
- # in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1).to(device)
- self.cos_cached = emb.cos().to(dtype)
- self.sin_cached = emb.sin().to(dtype)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if (seq_len > self.max_seq_len_cached
- or self.cos_cached.device != x.device # noqa: W503
- or self.cos_cached.dtype != x.dtype): # noqa: W503
- self._set_cos_sin_cache(
- seq_len=seq_len, device=x.device, dtype=x.dtype)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
-def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
- to (batch, seqlen, num_attention_heads, head_dim)"""
- batch, slen, num_key_value_heads, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, :,
- None, :].expand(batch, slen,
- num_key_value_heads, n_rep,
- head_dim)
- return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
- head_dim)
-
-
-def mistral_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-):
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in '
- 'v4.37. Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- assert position_ids is not None
- if self.training:
- cos, sin = self.rotary_emb(
- value_states, seq_len=position_ids.max() + 1)
- else:
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window)
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value
- # `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- 'past key must have a shape of (`batch_size, num_heads, '
- 'self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat(
- [attention_mask,
- torch.ones_like(attention_mask[:, -1:])],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and self.training)
- if enable_sequence_parallel:
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states)
- # num_heads has been changed because of sequence parallel
- # `self.num_heads`` is not used in self._flash_attention_forward
- # in mistral/mixtral, we are doing this to avoid some unnecessary risk
- ori_num_head = self.num_heads
- self.num_heads = query_states.shape[-2]
-
- if IS_LOW_VERSION_TRANSFORMERS:
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length=query_states.shape[1],
- dropout=dropout_rate,
- use_sliding_windows=use_sliding_windows,
- )
- else:
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=dropout_rate,
- sliding_window=getattr(self.config, 'sliding_window', None),
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- is_causal=self.is_causal,
- )
-
- if enable_sequence_parallel:
- attn_output = post_process_for_sequence_parallel_attn(attn_output)
- self.num_heads = ori_num_head
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.hidden_size).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def mistral_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-):
- is_training = self.training
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
-
- assert is_training == (past_key_value is None)
- use_varlen_atten = (cumulative_len is not None)
-
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37'
- ' Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
- bsz, q_len, _ = hidden_states.size()
- assert bsz == 1, (f'If utilizing local attention, the batch size should be'
- f' set to 1, but got {bsz}')
- # attention_mask is set to None if no padding token in input_ids
- assert attention_mask is None
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
-
- assert _flash_supports_window_size, \
- ('The current flash attention version does not support sliding window '
- 'attention, for a more memory efficient implementation make sure '
- 'to upgrade flash-attn library.')
-
- kv_seq_len = key_states.shape[-3]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- if use_varlen_atten:
- cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states,
- cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- else:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- # Because the input can be padded, the absolute sequence length
- # depends on the max position id.
- rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- # Activate slicing cache only if the config has a value
- # `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window # noqa: W503
- and cache_has_contents): # noqa: W503
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- 'past key must have a shape of (`batch_size, num_heads, '
- 'self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat(
- [attention_mask,
- torch.ones_like(attention_mask[:, -1:])],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # repeat kv for sequence parallel
- key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
- value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- # In PEFT, usually we cast the layer norms in float32 for
- # training stability reasons, therefore the input hidden states gets
- # silently casted in float32. Hence, we need
- # cast them back in float16 just to be sure everything works as expected.
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # ----------------- flash attention forward ------------------------#
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- causal = self.is_causal and q_len != 1
-
- use_sliding_windows = (
- _flash_supports_window_size and # noqa: W504
- getattr(self.config, 'sliding_window', None) is not None # noqa: W503
- and kv_seq_len > self.config.sliding_window) # noqa: W503
- window_size = (self.config.sliding_window,
- self.config.sliding_window) if use_sliding_windows else (-1,
- -1)
- if use_varlen_atten:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- causal=causal,
- dropout_p=dropout_rate,
- window_size=window_size,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=causal,
- dropout_p=dropout_rate,
- window_size=window_size,
- training=self.training)
-
- # ---------------- flash attention forward end ------------------- #
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.hidden_size).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/phi3.py b/code/xtuner/model/modules/dispatch/phi3.py
deleted file mode 100644
index 10f60f93983392643f3c1907b34af1bd48b2f03c..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/phi3.py
+++ /dev/null
@@ -1,480 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import inspect
-import warnings
-from typing import Optional, Tuple
-
-import torch
-import torch.distributed as dist
-import transformers
-from mmengine import MessageHub
-from mmengine.utils import digit_version
-
-from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn)
-from .attention import flash_attn_wo_mask, varlen_flash_attn
-
-try:
- from transformers.cache_utils import Cache
-except ImportError:
-
- class Cache:
- pass
-
-
-TRANSFORMERS_VERSION = digit_version(transformers.__version__)
-IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43')
-
-if not IS_LOW_VERSION_TRANSFORMERS:
- from transformers.modeling_flash_attention_utils import \
- _flash_attention_forward
-
-_flash_supports_window_size = False
-try:
- from flash_attn import flash_attn_func
-
- _flash_supports_window_size = 'window_size' in list(
- inspect.signature(flash_attn_func).parameters)
-
- if not _flash_supports_window_size:
- raise ValueError(
- 'Please update flash-attention to support window size.')
-# else:
-except ImportError:
- pass
-
-
-# Copied from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/3a811845d89f3c1b3f41b341d0f9f05104769f35/modeling_phi3.py#L302 # noqa:E501
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """This is the equivalent of torch.repeat_interleave(x, dim=1,
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/3a811845d89f3c1b3f41b341d0f9f05104769f35/modeling_phi3.py#L247 # noqa:E501
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/3a811845d89f3c1b3f41b341d0f9f05104769f35/modeling_phi3.py#L255 # noqa:E501
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """ # noqa:E501
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def phi3_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
-):
- if not _flash_supports_window_size:
- raise ValueError(
- 'The current flash attention version does not support '
- 'sliding window attention.')
-
- output_attentions = False
-
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in '
- 'v4.37. Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv = self.qkv_proj(hidden_states)
- query_pos = self.num_heads * self.head_dim
- query_states = qkv[..., :query_pos]
- key_states = qkv[..., query_pos:query_pos +
- self.num_key_value_heads * self.head_dim]
- value_states = qkv[...,
- query_pos + self.num_key_value_heads * self.head_dim:]
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
- cos, sin = self.rotary_emb(
- value_states, position_ids, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window)
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value
- # `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- 'past key must have a shape of (`batch_size, num_heads, '
- 'self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat(
- [attention_mask,
- torch.ones_like(attention_mask[:, -1:])],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_dropout = self.attention_dropout if self.training else 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32.
-
- if query_states.dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.qkv_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and self.training)
- if enable_sequence_parallel:
- # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim)
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states,
- scatter_dim=2, gather_dim=1)
- # num_heads has been changed because of sequence parallel
- # `self.num_heads`` is not used in self._flash_attention_forward
- # in mistral/mixtral, we are doing this to avoid some unnecessary risk
- ori_num_head = self.num_heads
- self.num_heads = query_states.shape[-2]
-
- if IS_LOW_VERSION_TRANSFORMERS:
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=attn_dropout,
- use_sliding_windows=use_sliding_windows,
- )
- else:
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=attn_dropout,
- sliding_window=getattr(self.config, 'sliding_window', None),
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- is_causal=self.is_causal,
- )
-
- if enable_sequence_parallel:
- # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
- attn_output = post_process_for_sequence_parallel_attn(
- attn_output, scatter_dim=1, gather_dim=2)
- self.num_heads = ori_num_head
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def phi3_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if not _flash_supports_window_size:
- raise ValueError(
- 'The current flash attention version does not support '
- 'sliding window attention.')
-
- output_attentions = False
-
- is_training = self.training
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
-
- assert is_training == (past_key_value is None)
- use_varlen_atten = (cumulative_len is not None)
-
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37'
- ' Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
-
- bsz, q_len, _ = hidden_states.size()
- assert bsz == 1, (f'If utilizing local attention, the batch size should be'
- f' set to 1, but got {bsz}')
- # attention_mask is set to None if no padding token in input_ids
- # varlen attn need data packing so no padding tokens in input_ids
- assert attention_mask is None
-
- qkv = self.qkv_proj(hidden_states)
- query_pos = self.num_heads * self.head_dim
- query_states = qkv[..., :query_pos]
- key_states = qkv[..., query_pos:query_pos +
- self.num_key_value_heads * self.head_dim]
- value_states = qkv[...,
- query_pos + self.num_key_value_heads * self.head_dim:]
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- assert position_ids is not None
- rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
- cos, sin = self.rotary_emb(
- value_states, position_ids, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window)
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value
- # `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- 'past key must have a shape of (`batch_size, num_heads, '
- 'self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat(
- [attention_mask,
- torch.ones_like(attention_mask[:, -1:])],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- # In PEFT, usually we cast the layer norms in float32 for
- # training stability reasons, therefore the input hidden states gets
- # silently casted in float32. Hence, we need
- # cast them back in float16 just to be sure everything works as expected.
-
- if query_states.dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.qkv_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # ----------------- flash attention forward ------------------------#
-
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- causal = self.is_causal and q_len != 1
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window)
-
- window_size = (self.config.sliding_window,
- self.config.sliding_window) if use_sliding_windows else (-1,
- -1)
- attn_dropout = self.attention_dropout if self.training else 0.0
-
- if use_varlen_atten:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- causal=causal,
- dropout_p=attn_dropout,
- window_size=window_size,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=causal,
- dropout_p=attn_dropout,
- window_size=window_size,
- training=self.training)
-
- # ---------------- flash attention forward end ------------------- #
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/qwen2.py b/code/xtuner/model/modules/dispatch/qwen2.py
deleted file mode 100644
index 20f2f40f382e4e88daf7b40a54611d9b781460a9..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/qwen2.py
+++ /dev/null
@@ -1,380 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import inspect
-import warnings
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-import transformers
-from mmengine import MessageHub
-from mmengine.utils import digit_version
-from transformers.cache_utils import Cache
-from transformers.models.qwen2.modeling_qwen2 import (apply_rotary_pos_emb,
- repeat_kv)
-
-from xtuner.parallel.sequence import get_sequence_parallel_world_size
-from xtuner.parallel.sequence.attention import (
- post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn)
-from .attention import flash_attn_wo_mask, varlen_flash_attn
-
-SUPPORT_FLASH2 = False
-
-try:
- from flash_attn import flash_attn_func
- _flash_supports_window_size = 'window_size' in list(
- inspect.signature(flash_attn_func).parameters)
- SUPPORT_FLASH2 = True
-except ImportError:
- pass
-
-TRANSFORMERS_VERSION = digit_version(transformers.__version__)
-IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43')
-
-if not IS_LOW_VERSION_TRANSFORMERS:
- from transformers.modeling_flash_attention_utils import \
- _flash_attention_forward
-
-
-def qwen2_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-):
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in '
- 'v4.37. Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- assert position_ids is not None
- rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and self.config.use_sliding_window)
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value
- # `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- 'past key must have a shape of (`batch_size, num_heads, '
- 'self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat(
- [attention_mask,
- torch.ones_like(attention_mask[:, -1:])],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- # In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons therefore the input hidden states gets silently
- # casted in float32. Hence, we need cast them back in the correct dtype
- # just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not
- # cast the LayerNorms in fp32.
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and self.training)
- if enable_sequence_parallel:
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states)
- # num_heads has been changed because of sequence parallel
- # `self.num_heads`` is not used in self._flash_attention_forward
- # in mistral/mixtral, we are doing this to avoid some unnecessary risk
- ori_num_head = self.num_heads
- self.num_heads = query_states.shape[-2]
-
- if IS_LOW_VERSION_TRANSFORMERS:
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length=query_states.shape[1],
- dropout=dropout_rate,
- use_sliding_windows=use_sliding_windows,
- )
- else:
- if (self.config.use_sliding_window
- and getattr(self.config, 'sliding_window', None) is not None
- and self.layer_idx >= self.config.max_window_layers):
- # There may be bugs here, but we are aligned with Transformers
- sliding_window = self.config.sliding_window
- else:
- sliding_window = None
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_states.shape[1],
- dropout=dropout_rate,
- sliding_window=sliding_window,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- if enable_sequence_parallel:
- attn_output = post_process_for_sequence_parallel_attn(attn_output)
- self.num_heads = ori_num_head
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-def qwen2_varlen_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-):
- is_training = self.training
-
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
-
- assert is_training == (past_key_value is None)
- use_varlen_atten = (cumulative_len is not None)
-
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37'
- ' Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
- self.layer_idx)
-
- assert position_ids is not None
- rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value
- # `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- 'past key must have a shape of (`batch_size, num_heads, '
- 'self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat(
- [attention_mask,
- torch.ones_like(attention_mask[:, -1:])],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- # In PEFT, usually we cast the layer norms in float32 for
- # training stability reasons, therefore the input hidden states gets
- # silently casted in float32. Hence, we need
- # cast them back in float16 just to be sure everything works as expected.
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # ----------------- flash attention forward ------------------------#
-
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- causal = self.is_causal and q_len != 1
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and self.config.use_sliding_window)
- # Decide whether to use SWA or not by layer index.
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
- use_sliding_windows = False
-
- window_size = (self.config.sliding_window,
- self.config.sliding_window) if use_sliding_windows else (-1,
- -1)
-
- if use_varlen_atten:
- attn_output = varlen_flash_attn(
- query_states,
- key_states,
- value_states,
- cumulative_len,
- max_seqlen,
- causal=causal,
- dropout_p=dropout_rate,
- window_size=window_size,
- training=self.training)
- else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal=causal,
- dropout_p=dropout_rate,
- window_size=window_size,
- training=self.training)
-
- # ---------------- flash attention forward end ------------------- #
-
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
diff --git a/code/xtuner/model/modules/dispatch/triton_kernels/__init__.py b/code/xtuner/model/modules/dispatch/triton_kernels/__init__.py
deleted file mode 100644
index ed29f409f853172a0c90f0e81b0200972c379e66..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/triton_kernels/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .layer_norm import layer_norm_forward
-from .rms_norm import rms_norm_forward
-from .rotary import apply_rotary_emb
-
-__all__ = ['rms_norm_forward', 'layer_norm_forward', 'apply_rotary_emb']
diff --git a/code/xtuner/model/modules/dispatch/triton_kernels/layer_norm.py b/code/xtuner/model/modules/dispatch/triton_kernels/layer_norm.py
deleted file mode 100644
index f808d6ad157a3ddbfeb6df02960c79739fcdc088..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/triton_kernels/layer_norm.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.nn.functional as F
-
-
-def layer_norm_forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- hidden_states = F.layer_norm(
- hidden_states, (hidden_states.shape[-1], ), eps=self.variance_epsilon)
- hidden_states = self.weight.to(torch.float32) * hidden_states
- return hidden_states.to(input_dtype)
diff --git a/code/xtuner/model/modules/dispatch/triton_kernels/rms_norm.py b/code/xtuner/model/modules/dispatch/triton_kernels/rms_norm.py
deleted file mode 100644
index 6191d55ba6e5e983d1e20c3e5282dffd439d2fd6..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/triton_kernels/rms_norm.py
+++ /dev/null
@@ -1,220 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def _rms_norm_fwd_fused(
- X, # pointer to the input
- Y, # pointer to the output
- W, # pointer to the weights
- Rstd, # pointer to the 1/std
- stride, # how much to increase the pointer when moving by 1 row
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- BLOCK_SIZE: tl.constexpr,
-):
- # Map the program id to the row of X and Y it should compute.
- row = tl.program_id(0)
- Y += row * stride
- X += row * stride
- # Compute variance
- _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
- for off in range(0, N, BLOCK_SIZE):
- cols = off + tl.arange(0, BLOCK_SIZE)
- x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
- _var += x * x
- var = tl.sum(_var, axis=0) / N
- rstd = 1 / tl.sqrt(var + eps)
- # Write rstd
- tl.store(Rstd + row, rstd)
- # Normalize and apply linear transformation
- for off in range(0, N, BLOCK_SIZE):
- cols = off + tl.arange(0, BLOCK_SIZE)
- mask = cols < N
- w = tl.load(W + cols, mask=mask)
- x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
- x_hat = x * rstd
- y = x_hat * w
- # Write output
- tl.store(Y + cols, y, mask=mask)
-
-
-@triton.jit
-def _rms_norm_bwd_dx_fused(
- DX, # pointer to the input gradient
- DY, # pointer to the output gradient
- DW, # pointer to the partial sum of weights gradient
- X, # pointer to the input
- W, # pointer to the weights
- Rstd, # pointer to the 1/std
- Lock, # pointer to the lock
- stride, # how much to increase the pointer when moving by 1 row
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- GROUP_SIZE_M: tl.constexpr,
- BLOCK_SIZE_N: tl.constexpr):
- # Map the program id to the elements of X, DX, and DY it should compute.
- row = tl.program_id(0)
- cols = tl.arange(0, BLOCK_SIZE_N)
- mask = cols < N
- X += row * stride
- DY += row * stride
- DX += row * stride
- # Offset locks and weights/biases gradient pointer for parallel reduction
- lock_id = row % GROUP_SIZE_M
- Lock += lock_id
- Count = Lock + GROUP_SIZE_M
- DW = DW + lock_id * N + cols
- # Load data to SRAM
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
- w = tl.load(W + cols, mask=mask).to(tl.float32)
- rstd = tl.load(Rstd + row)
- # Compute dx
- xhat = x * rstd
- wdy = w * dy
- xhat = tl.where(mask, xhat, 0.)
- wdy = tl.where(mask, wdy, 0.)
- c1 = tl.sum(xhat * wdy, axis=0) / N
- dx = (wdy - (xhat * c1)) * rstd
- # Write dx
- tl.store(DX + cols, dx, mask=mask)
- # Accumulate partial sums for dw/db
- partial_dw = (dy * xhat).to(w.dtype)
- while tl.atomic_cas(Lock, 0, 1) == 1:
- pass
- count = tl.load(Count)
- # First store doesn't accumulate
- if count == 0:
- tl.atomic_xchg(Count, 1)
- else:
- partial_dw += tl.load(DW, mask=mask)
- tl.store(DW, partial_dw, mask=mask)
- # Release the lock
- tl.atomic_xchg(Lock, 0)
-
-
-@triton.jit
-def _rms_norm_bwd_dwdb(
- DW, # pointer to the partial sum of weights gradient
- FINAL_DW, # pointer to the weights gradient
- M, # GROUP_SIZE_M
- N, # number of columns
- BLOCK_SIZE_M: tl.constexpr,
- BLOCK_SIZE_N: tl.constexpr):
- # Map the program id to the elements of DW and DB it should compute.
- pid = tl.program_id(0)
- cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- # Iterate through the rows of DW and DB to sum the partial sums.
- for i in range(0, M, BLOCK_SIZE_M):
- rows = i + tl.arange(0, BLOCK_SIZE_M)
- mask = (rows[:, None] < M) & (cols[None, :] < N)
- offs = rows[:, None] * N + cols[None, :]
- dw += tl.load(DW + offs, mask=mask, other=0.)
- # Write the final sum to the output.
- sum_dw = tl.sum(dw, axis=0)
- tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
-
-
-class RMSNorm(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, x, weight, eps):
- # allocate output
- y = torch.empty_like(x)
- # reshape input data into 2D tensor
- x_arg = x.reshape(-1, x.shape[-1])
- M, N = x_arg.shape
- rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
- # Less than 64KB per feature: enqueue fused kernel
- MAX_FUSED_SIZE = 65536 // x.element_size()
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
- if N > BLOCK_SIZE:
- raise RuntimeError(
- "This rms norm doesn't support feature dim >= 64KB.")
- # heuristics for number of warps
- num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
- # enqueue kernel
- _rms_norm_fwd_fused[(M, )](
- x_arg,
- y,
- weight,
- rstd,
- x_arg.stride(0),
- N,
- eps,
- BLOCK_SIZE=BLOCK_SIZE,
- num_warps=num_warps,
- )
- ctx.save_for_backward(x, weight, rstd)
- ctx.BLOCK_SIZE = BLOCK_SIZE
- ctx.num_warps = num_warps
- ctx.eps = eps
- return y
-
- @staticmethod
- def backward(ctx, dy):
- x, w, v = ctx.saved_tensors
- # heuristics for amount of parallel reduction stream for DW/DB
- N = w.shape[0]
- GROUP_SIZE_M = 64
- if N <= 8192:
- GROUP_SIZE_M = 96
- if N <= 4096:
- GROUP_SIZE_M = 128
- if N <= 1024:
- GROUP_SIZE_M = 256
- # allocate output
- locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda')
- _dw = torch.empty((GROUP_SIZE_M, w.shape[0]),
- dtype=x.dtype,
- device=w.device)
- dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device)
- dx = torch.empty_like(dy)
- # enqueue kernel using forward pass heuristics
- # also compute partial sums for DW and DB
- x_arg = x.reshape(-1, x.shape[-1])
- M, N = x_arg.shape
- _rms_norm_bwd_dx_fused[(M, )](
- dx,
- dy,
- _dw,
- x,
- w,
- v,
- locks,
- x_arg.stride(0),
- N,
- ctx.eps,
- BLOCK_SIZE_N=ctx.BLOCK_SIZE,
- GROUP_SIZE_M=GROUP_SIZE_M,
- num_warps=ctx.num_warps)
-
- def grid(meta):
- return [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
-
- # accumulate partial sums in separate kernel
- _rms_norm_bwd_dwdb[grid](
- _dw,
- dw,
- GROUP_SIZE_M,
- N,
- BLOCK_SIZE_M=32,
- BLOCK_SIZE_N=128,
- )
- return dx, dw, None
-
-
-rms_norm = RMSNorm.apply
-
-
-def rms_norm_forward(self, hidden_states):
- if (hidden_states.device == torch.device('cpu')
- or self.weight.device == torch.device('cpu')):
- raise RuntimeError(
- 'Can not use triton kernels on cpu. Please set `USE_TRITON_KERNEL`'
- ' environment variable to 0 before training.')
- return rms_norm(hidden_states, self.weight, self.variance_epsilon)
diff --git a/code/xtuner/model/modules/dispatch/triton_kernels/rotary.py b/code/xtuner/model/modules/dispatch/triton_kernels/rotary.py
deleted file mode 100644
index 1e09c16628751dbc769d1ca4ce7d0650de8f835b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/triton_kernels/rotary.py
+++ /dev/null
@@ -1,327 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Modified from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py # noqa:E501
-from typing import Optional, Union
-
-import torch
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def rotary_kernel(
- OUT, # Pointers to matrices
- X,
- COS,
- SIN,
- CU_SEQLENS,
- SEQLEN_OFFSETS, # this could be int or a pointer
- # Matrix dimensions
- seqlen,
- rotary_dim,
- seqlen_ro,
- # strides
- stride_out_batch,
- stride_out_seqlen,
- stride_out_nheads,
- stride_out_headdim,
- stride_x_batch,
- stride_x_seqlen,
- stride_x_nheads,
- stride_x_headdim,
- # Meta-parameters
- BLOCK_K: tl.constexpr,
- IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
- IS_VARLEN: tl.constexpr,
- INTERLEAVED: tl.constexpr,
- CONJUGATE: tl.constexpr,
- BLOCK_M: tl.constexpr,
-):
- pid_m = tl.program_id(axis=0)
- pid_batch = tl.program_id(axis=1)
- pid_head = tl.program_id(axis=2)
- rotary_dim_half = rotary_dim // 2
-
- if not IS_VARLEN:
- X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
- OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
- else:
- start_idx = tl.load(CU_SEQLENS + pid_batch)
- seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
- X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
- OUT = OUT + start_idx * stride_out_seqlen + \
- pid_head * stride_out_nheads
-
- if pid_m * BLOCK_M >= seqlen:
- return
- rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- if not IS_SEQLEN_OFFSETS_TENSOR:
- rm_cs = rm + SEQLEN_OFFSETS
- else:
- rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
- rk = tl.arange(0, BLOCK_K)
- rk_half = tl.arange(0, BLOCK_K // 2)
-
- if not INTERLEAVED:
- # Load the 1st and 2nd halves of X, do calculation,
- # then store to 1st and 2nd halves of OUT
- X = X + (
- rm[:, None] * stride_x_seqlen +
- rk_half[None, :] * stride_x_headdim)
- # This is different from the official implementation as the shapes of
- # the two tensors cos and sin are (seqlen_ro, rotary_dim) instead of
- # (seqlen_ro, rotary_dim // 2).
- COS = COS + (rm_cs[:, None] * rotary_dim + rk_half[None, :])
- SIN = SIN + (rm_cs[:, None] * rotary_dim + rk_half[None, :])
- cos = tl.load(
- COS,
- mask=(rm_cs[:, None] < seqlen_ro) &
- (rk_half[None, :] < rotary_dim_half),
- other=1.0).to(tl.float32)
- sin = tl.load(
- SIN,
- mask=(rm_cs[:, None] < seqlen_ro) &
- (rk_half[None, :] < rotary_dim_half),
- other=0.0).to(tl.float32)
- x0 = tl.load(
- X,
- mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
- other=0.0).to(tl.float32)
- x1 = tl.load(
- X + rotary_dim_half * stride_x_headdim,
- mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
- other=0.0,
- ).to(tl.float32)
- if CONJUGATE:
- sin = -sin
- o0 = x0 * cos - x1 * sin
- o1 = x0 * sin + x1 * cos
- # write back result
- OUT = OUT + (
- rm[:, None] * stride_out_seqlen +
- rk_half[None, :] * stride_out_headdim)
- tl.store(
- OUT,
- o0,
- mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
- tl.store(
- OUT + rotary_dim_half * stride_out_headdim,
- o1,
- mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
- )
- else:
- # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately
- # since both are slow.
- # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
- # Loading x0 will be fast but x1 will be slow.
- # Then we load cos = COS[0, 0, 1, 1, ...] and
- # sin = SIN[0, 0, 1, 1, ...].
- # Then we do the calculation and use tl.where to pick put the right
- # outputs for the even and for the odd indices.
- rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
- rk_repeat = tl.arange(0, BLOCK_K) // 2
- # This is different from the official implementation as the shapes of
- # the two tensors cos and sin are (seqlen_ro, rotary_dim) instead of
- # (seqlen_ro, rotary_dim // 2).
- X0 = X + (
- rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
- X1 = X + (
- rm[:, None] * stride_x_seqlen +
- rk_swap[None, :] * stride_x_headdim)
- COS = COS + (rm_cs[:, None] * rotary_dim + rk_repeat[None, :])
- SIN = SIN + (rm_cs[:, None] * rotary_dim + rk_repeat[None, :])
- cos = tl.load(
- COS,
- mask=(rm_cs[:, None] < seqlen_ro) &
- (rk_repeat[None, :] < rotary_dim_half),
- other=1.0,
- ).to(tl.float32)
- sin = tl.load(
- SIN,
- mask=(rm_cs[:, None] < seqlen_ro) &
- (rk_repeat[None, :] < rotary_dim_half),
- other=0.0,
- ).to(tl.float32)
- x0 = tl.load(
- X0,
- mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim),
- other=0.0).to(tl.float32)
- x1 = tl.load(
- X1,
- mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim),
- other=0.0).to(tl.float32)
- if CONJUGATE:
- sin = -sin
- x0_cos = x0 * cos
- x1_sin = x1 * sin
- out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
- OUT = OUT + (
- rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
- tl.store(
- OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
-
-
-def apply_rotary(
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- seqlen_offsets: Union[int, torch.Tensor] = 0,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- interleaved=False,
- inplace=False,
- conjugate=False,
-) -> torch.Tensor:
- """
- Arguments:
- x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
- else (total_seqlen, nheads, headdim).
- cos: (seqlen_ro, rotary_dim)
- sin: (seqlen_ro, rotary_dim)
- seqlen_offsets: integer or integer tensor of size (batch,)
- cu_seqlens: (batch + 1,) or None
- max_seqlen: int
- Returns:
- y: (batch, seqlen, nheads, headdim)
- """
- is_varlen = cu_seqlens is not None
- if not is_varlen:
- batch, seqlen, nheads, headdim = x.shape
- else:
- assert max_seqlen is not None, ('If cu_seqlens is passed in, '
- 'then max_seqlen must be passed')
- total_seqlen, nheads, headdim = x.shape
- batch_p_1 = cu_seqlens.shape[0]
- batch = batch_p_1 - 1
- seqlen = max_seqlen
- seqlen_ro, rotary_dim = cos.shape
- assert sin.shape == cos.shape
- # rotary_dim *= 2
- assert rotary_dim <= headdim, 'rotary_dim must be <= headdim'
- assert headdim <= 256, 'Only support headdim <= 256'
- assert seqlen_ro >= seqlen, 'seqlen_ro must be >= seqlen'
-
- assert (
- cos.dtype == sin.dtype
- ), f'cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}'
- assert (x.dtype == cos.dtype), (
- f'Input and cos/sin must have the same dtype, '
- f'got {x.dtype} and {cos.dtype}')
-
- cos, sin = cos.contiguous(), sin.contiguous()
- if isinstance(seqlen_offsets, torch.Tensor):
- assert seqlen_offsets.shape == (batch, )
- assert seqlen_offsets.dtype in [torch.int32, torch.int64]
- seqlen_offsets = seqlen_offsets.contiguous()
- else:
- assert seqlen_offsets + seqlen <= seqlen_ro
-
- output = torch.empty_like(x) if not inplace else x
- if rotary_dim < headdim and not inplace:
- output[..., rotary_dim:].copy_(x[..., rotary_dim:])
-
- BLOCK_K = (32 if rotary_dim <= 32 else
- (64 if rotary_dim <= 64 else
- (128 if rotary_dim <= 128 else 256)))
-
- def grid(META):
- return (triton.cdiv(seqlen, META['BLOCK_M']), batch, nheads)
-
- BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
-
- # Need this, otherwise Triton tries to launch from cuda:0 and we get
- # ValueError: Pointer argument (at 0) cannot be accessed from Triton
- # (cpu tensor?)
- with torch.cuda.device(x.device.index):
- rotary_kernel[grid](
- output, # data ptrs
- x,
- cos,
- sin,
- cu_seqlens,
- seqlen_offsets,
- seqlen, # shapes
- rotary_dim,
- seqlen_ro,
- output.stride(0)
- if not is_varlen else 0, # batch_strides if not varlen else 0
- output.stride(-3), # seqlen_stride or total_seqlen_stride
- output.stride(-2), # nheads_stride
- output.stride(-1), # headdim_stride
- x.stride(0)
- if not is_varlen else 0, # batch_strides if not varlen else 0
- x.stride(-3), # seqlen stride or total_seqlen_stride
- x.stride(-2), # nheads stride
- x.stride(-1), # headdim stride
- BLOCK_K,
- isinstance(seqlen_offsets, torch.Tensor),
- is_varlen,
- interleaved,
- conjugate,
- BLOCK_M,
- )
- return output
-
-
-class ApplyRotaryEmb(torch.autograd.Function):
-
- @staticmethod
- def forward(
- ctx,
- x,
- cos,
- sin,
- interleaved=False,
- inplace=False,
- seqlen_offsets: Union[int, torch.Tensor] = 0,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- ):
- out = apply_rotary(
- x,
- cos,
- sin,
- seqlen_offsets=seqlen_offsets,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- interleaved=interleaved,
- inplace=inplace,
- )
- if isinstance(seqlen_offsets, int):
- ctx.save_for_backward(
- cos, sin, cu_seqlens) # Can't save int with save_for_backward
- ctx.seqlen_offsets = seqlen_offsets
- else:
- ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
- ctx.seqlen_offsets = None
- ctx.interleaved = interleaved
- ctx.inplace = inplace
- ctx.max_seqlen = max_seqlen
- return out if not inplace else x
-
- @staticmethod
- def backward(ctx, do):
- seqlen_offsets = ctx.seqlen_offsets
- if seqlen_offsets is None:
- cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
- else:
- cos, sin, cu_seqlens = ctx.saved_tensors
- # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
- # "[CUDA]: invalid device context", and cloning makes it work. Idk why.
- # Triton 2.1.0 works.
- if not ctx.interleaved and not ctx.inplace:
- do = do.clone()
- dx = apply_rotary(
- do,
- cos,
- sin,
- seqlen_offsets=seqlen_offsets,
- cu_seqlens=cu_seqlens,
- max_seqlen=ctx.max_seqlen,
- interleaved=ctx.interleaved,
- inplace=ctx.inplace,
- conjugate=True,
- )
- return dx, None, None, None, None, None, None, None
-
-
-apply_rotary_emb = ApplyRotaryEmb.apply
diff --git a/code/xtuner/model/modules/dispatch/utils.py b/code/xtuner/model/modules/dispatch/utils.py
deleted file mode 100644
index 4cfa26cd1f98460a217862abe50f531389421a08..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/utils.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-try:
- from flash_attn.bert_padding import index_first_axis, unpad_input
-except ImportError:
- pass
-
-
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-def upad_qkv(query_layer, key_layer, value_layer, attention_mask,
- query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
- attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
- key_layer = index_first_axis(
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim), indices_k)
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim), indices_k)
- if query_length == kv_seq_len:
- # Different from the origin version as sequence parallel change
- # the number of attention heads.
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
- indices_k)
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \
- unpad_input(query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
diff --git a/code/xtuner/model/modules/dispatch/yi.py b/code/xtuner/model/modules/dispatch/yi.py
deleted file mode 100644
index 3c3e0d20ce04ee04edcf70380b8fcc220d9a7321..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/dispatch/yi.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
- # The first two dimensions of cos and sin are always 1,
- # so we can `squeeze` them.
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """This is the equivalent of torch.repeat_interleave(x, dim=1,
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-def yi_attn_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- # use flash attention implemented by pytorch
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
diff --git a/code/xtuner/model/modules/projector/__init__.py b/code/xtuner/model/modules/projector/__init__.py
deleted file mode 100644
index 6196093dd5ffa4f4be0821ae2198f17a86f685f6..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/projector/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from transformers import AutoConfig, AutoModel
-
-from .configuration_projector import ProjectorConfig
-from .modeling_projector import ProjectorModel
-
-AutoConfig.register('projector', ProjectorConfig)
-AutoModel.register(ProjectorConfig, ProjectorModel)
-
-__all__ = ['ProjectorConfig', 'ProjectorModel']
diff --git a/code/xtuner/model/modules/projector/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/modules/projector/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 285ffc156afb25fc9892332990bcabc6ab7fdb31..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/modules/projector/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/modules/projector/__pycache__/configuration_projector.cpython-311.pyc b/code/xtuner/model/modules/projector/__pycache__/configuration_projector.cpython-311.pyc
deleted file mode 100644
index ed56933e2c13f6eb9349016ab1b2ae79991398f8..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/modules/projector/__pycache__/configuration_projector.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/modules/projector/__pycache__/modeling_projector.cpython-311.pyc b/code/xtuner/model/modules/projector/__pycache__/modeling_projector.cpython-311.pyc
deleted file mode 100644
index c37796aa91d84f0b3f51710041bc40074695d4f4..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/modules/projector/__pycache__/modeling_projector.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/modules/projector/configuration_projector.py b/code/xtuner/model/modules/projector/configuration_projector.py
deleted file mode 100644
index f63ffdc4698bc867bd559370ea8766537270661c..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/projector/configuration_projector.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from transformers import PretrainedConfig
-
-
-class ProjectorConfig(PretrainedConfig):
- model_type = 'projector'
- _auto_class = 'AutoConfig'
-
- def __init__(
- self,
- visual_hidden_size=4096,
- llm_hidden_size=4096,
- depth=2,
- hidden_act='gelu',
- bias=True,
- **kwargs,
- ):
- self.visual_hidden_size = visual_hidden_size
- self.llm_hidden_size = llm_hidden_size
- self.depth = depth
- self.hidden_act = hidden_act
- self.bias = bias
- super().__init__(**kwargs)
diff --git a/code/xtuner/model/modules/projector/modeling_projector.py b/code/xtuner/model/modules/projector/modeling_projector.py
deleted file mode 100644
index 6bd23a7a0847d8f10a62d1571e7dad440215bb06..0000000000000000000000000000000000000000
--- a/code/xtuner/model/modules/projector/modeling_projector.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel
-from transformers.activations import ACT2FN
-
-from .configuration_projector import ProjectorConfig
-from mmengine import print_log
-
-class ProjectorModel(PreTrainedModel):
- _auto_class = 'AutoModel'
- config_class = ProjectorConfig
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
-
- def __init__(self, config: ProjectorConfig) -> None:
- super().__init__(config)
- self.gradient_checkpointing = False
-
- modules = [
- nn.Linear(
- config.visual_hidden_size,
- config.llm_hidden_size,
- bias=config.bias)
- ]
- for _ in range(1, config.depth):
- modules.append(ACT2FN[config.hidden_act])
- modules.append(
- nn.Linear(
- config.llm_hidden_size,
- config.llm_hidden_size,
- bias=config.bias))
- self.model = nn.Sequential(*modules)
-
- def enable_input_require_grads(self):
- print_log("enable input required grads for projector", 'current')
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- self.model.register_forward_hook(make_inputs_require_grad)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, ProjectorModel):
- module.gradient_checkpointing = value
-
- def forward(self, x):
- if self.gradient_checkpointing and self.training:
- layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
- else:
- layer_outputs = self.model(x)
- return layer_outputs
diff --git a/code/xtuner/model/orpo.py b/code/xtuner/model/orpo.py
deleted file mode 100644
index 37264088acd7c852865e0dcd7795796bd8990eeb..0000000000000000000000000000000000000000
--- a/code/xtuner/model/orpo.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne
-# Official code: https://github.com/xfactlab/orpo
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from mmengine import MessageHub
-from torch import nn
-
-from xtuner.parallel.sequence import (gather_forward_split_backward,
- get_sequence_parallel_group,
- get_sequence_parallel_world_size,
- split_for_sequence_parallel)
-from .sft import SupervisedFinetune
-
-
-class ORPO(SupervisedFinetune):
- """ORPO: Monolithic Preference Optimization without Reference Model
- https://arxiv.org/abs/2403.07691
-
- Args:
- beta (float): Weight of the odds_ratio_loss. Defaults to 0.1.
- """
-
- def __init__(self, *args, beta=0.1, **kwargs):
- super().__init__(*args, **kwargs)
- self.beta = beta
-
- def _gather_masked_logits(self, logits, labels, mask):
- logits = torch.gather(
- logits.log_softmax(-1), dim=2,
- index=labels.unsqueeze(2)).squeeze(2)
- return logits * mask
-
- def get_logps(
- self,
- all_logps, # bs, seqlen
- average_log_prob,
- loss_mask, # bs, seqlen
- ):
- all_logps = all_logps[:, :-1].sum(-1)
- loss_mask = loss_mask[:, :-1]
-
- if average_log_prob: # average_log_prob
- all_logps = all_logps / loss_mask.sum(-1)
-
- chosen_logps = all_logps[::2]
- rejected_logps = all_logps[1::2]
- return chosen_logps, rejected_logps
-
- def get_var_len_atten_logps(self, all_logps, average_log_prob, loss_mask,
- cu_seqlens, attention_mask):
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- # unpack sequence
- unpacked_logps = torch.split(all_logps, seqlens, dim=1)
- unpacked_loss_mask = torch.split(loss_mask, seqlens, dim=1)
- if attention_mask is not None:
- # It indicate that we pad the original sequence, labels,
- # position_ids and cumulative_len for sequence parallel if the
- # attention_mask is not None.
- # We then need to remove the padded segments.
- assert False in attention_mask
- unpacked_logps = unpacked_logps[:-1]
- unpacked_loss_mask = unpacked_loss_mask[:-1]
- assert len(unpacked_logps) % 2 == 0
-
- def compute_logps(_logps, _mask):
- _logps = _logps[:, :-1].sum(-1)
- _mask = _mask[:, :-1]
- if average_log_prob:
- _logps /= _mask.sum(-1)
- return _logps
-
- chosen_logps, rejected_logps = [], []
- for i in range(len(unpacked_logps) // 2):
- chosen = unpacked_logps[2 * i]
- rejected = unpacked_logps[2 * i + 1]
- chosen_mask = unpacked_loss_mask[2 * i]
- rejected_mask = unpacked_loss_mask[2 * i + 1]
- chosen_logps.append(compute_logps(chosen, chosen_mask))
- rejected_logps.append(compute_logps(rejected, rejected_mask))
-
- return (torch.stack(chosen_logps), torch.stack(rejected_logps))
-
- def cross_entropy_loss(self, logits, labels):
- logits = logits[..., :-1, :].contiguous()
- # labels are already shifted, now we need to remove the last dummy label # noqa
- labels = labels[..., :-1].contiguous()
- # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
- logits = logits.view(-1, logits.shape[-1])
- labels = labels.view(-1)
- # Enable model parallelism
- labels = labels.to(logits.device)
- loss = loss_fct(logits, labels)
- return loss
-
- def odds_ratio_loss(
- self,
- chosen_logps: torch.FloatTensor,
- rejected_logps: torch.FloatTensor,
- ):
- # modified from https://github.com/huggingface/trl/blob/b031adfdb8708f1f295eab6c3f2cb910e8fe0c23/trl/trainer/orpo_trainer.py#L597 # noqa
- # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) # noqa
- log_odds = (chosen_logps - rejected_logps) - (
- torch.log1p(-torch.exp(chosen_logps)) -
- torch.log1p(-torch.exp(rejected_logps)))
- ratio = F.logsigmoid(log_odds)
- ratio = ratio[~torch.isnan(ratio)] # select valid loss
- losses = self.beta * ratio
-
- chosen_rewards = self.beta * chosen_logps
- rejected_rewards = self.beta * rejected_logps
-
- return losses, chosen_rewards, rejected_rewards, torch.mean(
- ratio), torch.mean(log_odds)
-
- @staticmethod
- def _split_for_sequence_parallel(data):
- # attention mask should not be split
- ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids', 'labels',
- 'chosen_rejected_tag')
- sp_group = get_sequence_parallel_group()
- for key in ARGS_NEED_TO_SPLIT:
- val = data.get(key, None)
- if val is not None:
- # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
- data[key] = split_for_sequence_parallel(
- val, dim=1, sp_group=sp_group)
- return data
-
- def compute_loss(self, data, data_samples=None):
- # shift labels first and add a dummy label at the end, to support sequence parallel # noqa
- data['labels'] = torch.cat(
- (data['labels'][:, 1:], torch.zeros_like(data['labels'][:, :1])),
- dim=1)
- tmp_label = data['labels'].clone()
- tmp_label[tmp_label == 0] = -100
- # loss mask of all tokens in all sp ranks
- all_loss_mask = data['labels'] != -100
-
- if self.use_varlen_attn:
- # create a chosen rejected tag for varlen_attn ce loss
- message_hub = MessageHub.get_instance('varlen_attn_args')
- rank = dist.get_rank()
- cu_seqlens = message_hub.get_info(f'cumulative_len_rank_{rank}')
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-
- chosen_rejected_tag = torch.ones_like(data['labels'])
- unpacked_tag = list(
- torch.split(chosen_rejected_tag, seqlens, dim=1))
- # import pdb; pdb.set_trace()
- for i in range(len(unpacked_tag) // 2):
- # import pdb; pdb.set_trace()
- unpacked_tag[2 * i + 1] *= 0
- chosen_rejected_tag = torch.cat(unpacked_tag, dim=1)
- data['chosen_rejected_tag'] = chosen_rejected_tag
-
- if get_sequence_parallel_world_size() > 1:
- data = self._split_for_sequence_parallel(data)
- chosen_rejected_tag = data.pop('chosen_rejected_tag', None)
- all_logits = self.llm(**data).logits
-
- labels = data['labels'].clone()
- labels[labels == -100] = 0
- loss_mask = labels != 0 # loss mask in a single sp rank
- all_logps = self._gather_masked_logits(all_logits, labels, loss_mask)
- if get_sequence_parallel_world_size() > 1:
- all_logps = gather_forward_split_backward(
- all_logps,
- dim=1,
- sp_group=get_sequence_parallel_group(),
- grad_scale='up')
-
- if not self.use_varlen_attn:
- chosen_nll_loss = self.cross_entropy_loss(all_logits[::2],
- data['labels'][::2])
- chosen_logps, rejected_logps = self.get_logps(
- all_logps, True, all_loss_mask)
- else:
- chosen_idxs = chosen_rejected_tag == 1
- chosen_logits = all_logits[chosen_idxs]
- chosen_labels = data['labels'][chosen_idxs]
- chosen_nll_loss = self.cross_entropy_loss(chosen_logits,
- chosen_labels)
-
- chosen_logps, rejected_logps = self.get_var_len_atten_logps(
- all_logps, True, all_loss_mask, cu_seqlens,
- data['attention_mask'])
- (losses, chosen_rewards, rejected_rewards, log_odds_ratio,
- log_odds_chosen) = self.odds_ratio_loss(chosen_logps, rejected_logps)
- losses = losses.mean()
- # skip nan loss
- if torch.isnan(chosen_nll_loss):
- chosen_nll_loss = all_logits.mean() * 0
- if torch.isnan(losses):
- losses = all_logits.mean() * 0
- loss = chosen_nll_loss - losses
-
- reward_acc = (chosen_rewards > rejected_rewards).float().mean()
-
- loss_dict = {
- 'loss': loss,
- 'chosen_rewards': chosen_rewards.mean(),
- 'rejected_rewards': rejected_rewards.mean(),
- 'reward_acc': reward_acc,
- 'reward_margin': (chosen_rewards - rejected_rewards).mean(),
- 'log_odds_ratio': log_odds_ratio,
- 'log_odds_chosen': log_odds_chosen,
- 'nll_loss': chosen_nll_loss.detach().mean()
- }
- return loss_dict
diff --git a/code/xtuner/model/qwen2_perceiver_resampler.py b/code/xtuner/model/qwen2_perceiver_resampler.py
deleted file mode 100644
index 61d5c3a13a26f50e6a940046cbcebff6d971115f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/qwen2_perceiver_resampler.py
+++ /dev/null
@@ -1,618 +0,0 @@
-import math
-from typing import Optional, Tuple, Iterable, Dict
-import os, json, warnings
-import torch
-import torch.nn as nn
-from mmengine import print_log
-
-
-from transformers.models.qwen2.modeling_qwen2 import (
- Qwen2RMSNorm,
- Qwen2MLP,
- eager_attention_forward,
-)
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-
-
-# ----------------------------- Helpers ----------------------------------------
-def sinusoidal_positions(L: int, D: int, device=None, base: float = 10000.0):
- """
- Standard Transformer sinusoidal absolute positions: (1, L, D)
- Works for any D (even/odd handled by slicing).
- """
- position = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(1) # (L,1)
- div_term = torch.exp(torch.arange(0, D, 2, device=device, dtype=torch.float32) *
- (-math.log(base) / max(1, D // 2)))
- pe = torch.zeros(L, D, dtype=torch.float32, device=device)
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- return pe.unsqueeze(0) # (1, L, D)
-
-
-# ------------------------- Cross-Attention ------------------------------------
-class Qwen2CrossAttention(nn.Module):
- """
- Cross-attention that mirrors Qwen2Attention's backend interface.
-
- - Queries: concatenated [K learnable latents ⊕ text tokens]
- - Keys/Values: visual tokens
- - Backends: eager / sdpa / flash_attn2 via ALL_ATTENTION_FUNCTIONS
- - Positional handling: none here (positions are added in the caller).
- """
- def __init__(self, config, layer_idx: int):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
-
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_heads = config.num_attention_heads
- self.num_kv_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
-
- self.scaling = self.head_dim ** -0.5
- self.attention_dropout = config.attention_dropout
- self.is_causal = False # cross-attn is not causal
-
- self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=True)
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=True)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
-
- def forward(
- self,
- query_hidden_states: torch.Tensor, # (B, Q, D) -> [latents ⊕ text]
- key_value_hidden_states: torch.Tensor, # (B, N, D) -> visual tokens
- attention_mask: Optional[torch.Tensor] = None, # ignored; always None in our caller
- **kwargs,
- ):
- B, Q, _ = query_hidden_states.shape
- N = key_value_hidden_states.shape[1]
- Hd = self.head_dim
-
- # Projections
- q = self.q_proj(query_hidden_states).view(B, Q, self.num_heads, Hd).transpose(1, 2) # (B, H, Q, Hd)
- k = self.k_proj(key_value_hidden_states).view(B, N, self.num_kv_heads, Hd).transpose(1, 2)
- v = self.v_proj(key_value_hidden_states).view(B, N, self.num_kv_heads, Hd).transpose(1, 2)
-
- attention_interface = eager_attention_forward
- if getattr(self.config, "_attn_implementation", "eager") != "eager":
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, _ = attention_interface(
- self,
- q, k, v,
- attention_mask=None, # <- force None (no masks)
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=None,
- **kwargs,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous().view(B, Q, self.num_heads * Hd)
- return self.o_proj(attn_output)
-
-
-# ------------------------- Perceiver Resampler --------------------------------
-class PerceiverResampler(nn.Module):
- """
- Perceiver-style resampler with:
- - Optional concatenation of text tokens to the learnable query latents
- (controlled by `concat_text_to_queries`)
- - NO attention masks
- - NO RoPE
- - A single learned absolute positional embedding applied to queries
- over either [latents ⊕ text] or [latents] only
- - Optional gradient checkpointing per block
-
- Args:
- llm: Qwen2 model (or wrapper exposing .model and get_input_embeddings()).
- num_latents: number of learnable latent slots (K).
- depth: number of Perceiver blocks.
- max_text_len: maximum supported text length (T_max).
- concat_text_to_queries: if True, queries are [latents ⊕ text]; if False, queries are [latents] only.
- """
- def __init__(
- self,
- llm,
- num_latents: int = 64,
- depth: int = 2,
- *,
- max_text_len: int = 4096,
- concat_text_to_queries: bool = True,
- ):
- super().__init__()
- base = llm.model if hasattr(llm, "model") else llm
-
- self.config = base.config
- self.hidden_size = self.config.hidden_size
- self.num_latents = num_latents
- self.depth = depth
- self.max_text_len = max_text_len
-
- # NEW: whether to append text to query slots
- self.concat_text_to_queries: bool = concat_text_to_queries
-
- # Learnable latent queries (K, D)
- self.latents = nn.Parameter(
- torch.randn(1, num_latents, self.hidden_size) / math.sqrt(self.hidden_size)
- )
-
- self.visual_ln = Qwen2RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps)
-
- # Perceiver blocks
- self.blocks = nn.ModuleList()
- for i in range(depth):
- self.blocks.append(nn.ModuleDict({
- "input_ln": Qwen2RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps),
- "cross_attn": Qwen2CrossAttention(self.config, layer_idx=i),
- "post_ln": Qwen2RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps),
- "mlp": Qwen2MLP(self.config),
- }))
-
- # Learned ABS positional embedding across query positions.
- # We allocate enough positions for the maximum possible query length: K + T_max.
- self.query_pos = nn.Embedding(self.num_latents + self.max_text_len, self.hidden_size)
- nn.init.normal_(self.query_pos.weight, mean=0.0, std=0.02)
-
- self.resid_scale = 1.0 / math.sqrt(2.0)
-
- # Optional (handy if you embed ids upstream)
- self.text_embed = llm.get_input_embeddings()
-
- # ---- Gradient checkpointing controls (off by default) ----
- self.gradient_checkpointing: bool = False
- self.gc_use_reentrant: bool = True
- self.gc_preserve_rng_state: bool = True
-
- # Public helpers to toggle checkpointing
- def enable_gradient_checkpointing(self, *, use_reentrant: bool = True, preserve_rng_state: bool = True):
- self.gradient_checkpointing = True
- self.gc_use_reentrant = use_reentrant
- self.gc_preserve_rng_state = preserve_rng_state
-
- def disable_gradient_checkpointing(self):
- self.gradient_checkpointing = False
-
- def enable_input_require_grads(self):
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
- self.register_forward_hook(make_inputs_require_grad)
-
- # NEW: runtime toggle
- def set_concat_text_queries(self, enabled: bool):
- """If False, only the learnable latents are used as queries."""
- self.concat_text_to_queries = enabled
-
- def forward(
- self,
- text_embeddings: torch.FloatTensor, # (B, T, D)
- visual_tokens: torch.FloatTensor, # (B, N, D)
- attention_mask: Optional[torch.Tensor] = None, # ignored
- visual_mask: Optional[torch.Tensor] = None, # ignored
- ) -> torch.Tensor:
- device = self.latents.device
-
- B, T, D = text_embeddings.shape
- assert D == self.hidden_size, f"text hidden {D} != {self.hidden_size}"
- if T > self.max_text_len:
- raise ValueError(
- f"text length {T} exceeds max_text_len={self.max_text_len}; "
- f"increase max_text_len when constructing PerceiverResampler."
- )
-
- # Queries: either [latents ⊕ text] (default) or [latents] only
- K = self.num_latents
- Q_lat = self.latents.expand(B, -1, -1) # (B, K, D)
- if self.concat_text_to_queries:
- x = torch.cat([Q_lat, text_embeddings], dim=1) # (B, K+T, D)
- else:
- x = Q_lat # (B, K, D)
- Q = x.size(1)
-
- # Learned absolute positions across queries (consume first Q positions)
- pos_ids = torch.arange(Q, device=device).unsqueeze(0).expand(B, -1) # (B, Q)
- x = x + self.query_pos(pos_ids)
-
- visual_tokens = self.visual_ln(visual_tokens)
- # Per-block forward (optionally checkpointed)
- for blk_idx, blk in enumerate(self.blocks):
- def _block_fn(t_x, t_v):
- r = t_x
- xn = blk["input_ln"](t_x)
- a = blk["cross_attn"](xn, t_v, attention_mask=None)
- t_x = r + a * self.resid_scale
- r = t_x
- xn = blk["post_ln"](t_x)
- return r + blk["mlp"](xn) * self.resid_scale
-
- if self.gradient_checkpointing and self.training:
- x = torch.utils.checkpoint.checkpoint(
- _block_fn,
- x, visual_tokens,
- use_reentrant=self.gc_use_reentrant,
- preserve_rng_state=self.gc_preserve_rng_state,
- )
- else:
- x = _block_fn(x, visual_tokens)
-
- # Return only the latent slots (K, D)
- return x[:, :self.num_latents, :]
-
-try:
- from safetensors.torch import load_file as safe_load_file
- _HAS_SAFE = True
-except Exception:
- _HAS_SAFE = False
-
-
-def _find_weight_index(ckpt_dir: str) -> Optional[str]:
- """Return the path to the model weight index (if sharded) or None."""
- cands = ["model.safetensors.index.json", "pytorch_model.bin.index.json"]
- for c in cands:
- p = os.path.join(ckpt_dir, c)
- if os.path.isfile(p):
- return p
- return None
-
-
-def _list_all_weight_files(ckpt_dir: str) -> Iterable[str]:
- """Yield all likely weight files in a directory."""
- for name in os.listdir(ckpt_dir):
- if name.endswith(".safetensors") or name.endswith(".bin"):
- # skip the top-level consolidated adapter/optimizer etc.
- if "optimizer" in name or "trainer" in name or name.endswith(".index.json"):
- continue
- yield os.path.join(ckpt_dir, name)
-
-
-def _load_shard(shard_path: str) -> Dict[str, torch.Tensor]:
- """Load one shard to CPU."""
- if shard_path.endswith(".safetensors"):
- if not _HAS_SAFE:
- raise RuntimeError("safetensors not available; install safetensors or provide .bin weights.")
- return safe_load_file(shard_path, device="cpu")
- return torch.load(shard_path, map_location="cpu")
-
-
-def _gather_needed_tensors_from_checkpoint(ckpt_dir: str, needed_keys: Iterable[str]) -> Dict[str, torch.Tensor]:
- """
- Load only the tensors we need from a (possibly sharded) HF checkpoint dir.
- """
- needed = set(needed_keys)
- out: Dict[str, torch.Tensor] = {}
-
- index_path = _find_weight_index(ckpt_dir)
- if index_path is None:
- # Non-sharded: scan files and pick keys if present
- for fpath in _list_all_weight_files(ckpt_dir):
- shard = _load_shard(fpath)
- for k in list(needed):
- if k in shard:
- out[k] = shard[k]
- needed.remove(k)
- # free asap
- del shard
- if not needed:
- break
- else:
- # Sharded: index maps param key -> shard filename
- with open(index_path, "r") as f:
- idx = json.load(f)
- weight_map = idx.get("weight_map") or idx.get("weight_map", {})
- # group by shard
- shard_to_keys: Dict[str, list] = {}
- for k in needed:
- shard_name = weight_map.get(k)
- if shard_name is None:
- continue
- shard_to_keys.setdefault(shard_name, []).append(k)
- # load per shard
- for shard_name, keys in shard_to_keys.items():
- shard_path = os.path.join(ckpt_dir, shard_name)
- shard = _load_shard(shard_path)
- for k in keys:
- if k in shard:
- out[k] = shard[k]
- del shard
-
- needed = {k for k in needed if k not in out}
-
- if needed:
- missing_sorted = "\n - " + "\n - ".join(sorted(needed))
- raise KeyError(f"Missing keys in checkpoint for Perceiver init:{missing_sorted}")
- return out
-
-
-def _copy_param_like(dst_param: torch.nn.Parameter, src_tensor: torch.Tensor):
- print_log(
- f'Copying param {dst_param.shape} <- {src_tensor.shape}',
- logger='current'
- )
- dst_param.data.copy_(src_tensor.to(dtype=dst_param.dtype, device=dst_param.device))
-
-
-def _safe_copy_linear_from_tensor(dst_linear: torch.nn.Linear, w: torch.Tensor, b: Optional[torch.Tensor]):
- if dst_linear.weight.shape != w.shape:
- raise RuntimeError(
- f"Shape mismatch copying linear: dst {tuple(dst_linear.weight.shape)} vs src {tuple(w.shape)}"
- )
- print_log(
- f'Copying linear {dst_linear.weight.shape}, bias={dst_linear.bias is not None}',
- logger='current'
- )
- dst_linear.weight.data.copy_(w.to(dtype=dst_linear.weight.dtype, device=dst_linear.weight.device))
- if dst_linear.bias is not None:
- if b is not None:
- dst_linear.bias.data.copy_(b.to(dtype=dst_linear.bias.dtype, device=dst_linear.bias.device))
- else:
- dst_linear.bias.data.zero_()
-
-
-def init_perceiver_from_llm_checkpoint(
- perceiver,
- ckpt_dir: str,
- init_from_layers: Optional[int] = None,
- layer_offset: int = 0,
-):
- """
- Initialize PerceiverResampler from the raw LLM checkpoint files on disk.
- - Supports .safetensors or .bin, sharded or single-file.
- - Copies: input/post norms, q/k/v/o, mlp gate/up/down for the first `L` layers.
- - `layer_offset` lets you start from a later LLM block if you prefer.
-
- Args:
- perceiver: PerceiverResampler instance (with .blocks ModuleList)
- ckpt_dir: path to LLM checkpoint directory (the one you pass to from_pretrained)
- init_from_layers: how many LLM layers to use (defaults to perceiver.depth)
- layer_offset: start copying from LLM layer `layer_offset` (default 0)
- """
- base_depth = perceiver.depth
- L = min(init_from_layers or base_depth, base_depth)
-
- # Build the list of keys we need from the LLM checkpoint
- needed = []
- for i in range(L):
- li = i + layer_offset
- prefix = f"model.layers.{li}"
- # norms
- needed += [
- f"{prefix}.input_layernorm.weight",
- f"{prefix}.post_attention_layernorm.weight",
- ]
- # attention
- needed += [
- f"{prefix}.self_attn.q_proj.weight",
- f"{prefix}.self_attn.k_proj.weight",
- f"{prefix}.self_attn.v_proj.weight",
- f"{prefix}.self_attn.o_proj.weight",
- ]
- # mlp
- needed += [
- f"{prefix}.mlp.gate_proj.weight",
- f"{prefix}.mlp.up_proj.weight",
- f"{prefix}.mlp.down_proj.weight",
- ]
- # optional biases (some Qwen2 variants have none; we’ll tolerate missing)
- needed += [
- f"{prefix}.self_attn.q_proj.bias",
- f"{prefix}.self_attn.k_proj.bias",
- f"{prefix}.self_attn.v_proj.bias",
- f"{prefix}.self_attn.o_proj.bias",
- f"{prefix}.mlp.gate_proj.bias",
- f"{prefix}.mlp.up_proj.bias",
- f"{prefix}.mlp.down_proj.bias",
- ]
-
- # Load what's available; we’ll allow bias keys to be missing without failing:
- try:
- tensors = _gather_needed_tensors_from_checkpoint(ckpt_dir, [k for k in needed if "bias" not in k])
- except KeyError as e:
- # Re-raise with a hint if fuse-qkv is detected
- msg = str(e)
- if "W_pack" in msg or "qkv" in msg:
- msg += "\nDetected fused QKV in checkpoint. This loader expects separate q_proj/k_proj/v_proj. "\
- "If your checkpoint uses fused QKV (e.g., *.W_pack.weight), we’ll need a small slicer—ping me."
- raise
-
- # Biases: try to load, but don't error if absent
- bias_tensors: Dict[str, torch.Tensor] = {}
- idx_path = _find_weight_index(ckpt_dir)
- if idx_path is not None:
- with open(idx_path, "r") as f:
- idx = json.load(f)
- wmap = idx.get("weight_map") or {}
- shard_to_biases: Dict[str, list] = {}
- for k in needed:
- if "bias" not in k:
- continue
- sn = wmap.get(k)
- if sn is None:
- continue
- shard_to_biases.setdefault(sn, []).append(k)
- for sn, keys in shard_to_biases.items():
- shard_path = os.path.join(ckpt_dir, sn)
- shard = _load_shard(shard_path)
- for k in keys:
- if k in shard:
- bias_tensors[k] = shard[k]
- del shard
- else:
- # non-sharded: scan files once
- for fpath in _list_all_weight_files(ckpt_dir):
- shard = _load_shard(fpath)
- for k in [k for k in needed if "bias" in k]:
- if k in shard:
- bias_tensors[k] = shard[k]
- del shard
-
- # Copy into perceiver blocks
- with torch.no_grad():
- for i in range(L):
- li = i + layer_offset
- prefix = f"model.layers.{li}"
- dst = perceiver.blocks[i]
-
- # norms
- _copy_param_like(dst["input_ln"].weight, tensors[f"{prefix}.input_layernorm.weight"])
- _copy_param_like(dst["post_ln"].weight, tensors[f"{prefix}.post_attention_layernorm.weight"])
-
- # attention
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].q_proj,
- tensors[f"{prefix}.self_attn.q_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.q_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].k_proj,
- tensors[f"{prefix}.self_attn.k_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.k_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].v_proj,
- tensors[f"{prefix}.self_attn.v_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.v_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].o_proj,
- tensors[f"{prefix}.self_attn.o_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.o_proj.bias"),
- )
-
- # mlp
- _safe_copy_linear_from_tensor(
- dst["mlp"].gate_proj,
- tensors[f"{prefix}.mlp.gate_proj.weight"],
- bias_tensors.get(f"{prefix}.mlp.gate_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["mlp"].up_proj,
- tensors[f"{prefix}.mlp.up_proj.weight"],
- bias_tensors.get(f"{prefix}.mlp.up_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["mlp"].down_proj,
- tensors[f"{prefix}.mlp.down_proj.weight"],
- bias_tensors.get(f"{prefix}.mlp.down_proj.bias"),
- )
-
-
-
-def resolve_llm_checkpoint_dir(llm, ckpt_hint: str | None = None, allow_download: bool = False) -> str | None:
- """
- Try to find a local directory for the LLM checkpoint.
- - If 'ckpt_hint' is a real directory, use it.
- - Else check llm.config._name_or_path / llm.name_or_path.
- - If those look like HF repo ids, query the local HF cache (snapshot_download with local_files_only=True).
- - If still not found and 'allow_download' is True, try downloading (optional).
- Returns an absolute directory path or None if it couldn't be resolved.
- """
- # 1) explicit hint
- if ckpt_hint and os.path.isdir(ckpt_hint):
- return os.path.abspath(ckpt_hint)
-
- candidates = []
- # Usual places Transformers stores the source string
- if getattr(getattr(llm, "config", None), "_name_or_path", None):
- candidates.append(llm.config._name_or_path)
- if getattr(llm, "name_or_path", None):
- candidates.append(llm.name_or_path)
-
- # 2) if any candidate is already a dir, take it
- for cand in candidates:
- if isinstance(cand, str) and os.path.isdir(cand):
- return os.path.abspath(cand)
-
- # 3) HF cache lookup for repo ids
- try:
- from huggingface_hub import snapshot_download
- except Exception:
- snapshot_download = None
-
- for cand in candidates:
- if not isinstance(cand, str):
- continue
- # Heuristic: repo ids usually contain '/'
- looks_like_repo_id = "/" in cand and not os.path.isabs(cand)
- if snapshot_download is None or not looks_like_repo_id:
- continue
- # First: local cache only (offline-safe)
- try:
- path = snapshot_download(
- repo_id=cand,
- local_files_only=True,
- # Narrow to model files to avoid pulling huge repos
- allow_patterns=["*.safetensors", "*.bin", "*.json", "*.index.json", "config.json"],
- )
- return path
- except Exception:
- # Optional online fetch if allowed
- if allow_download:
- try:
- path = snapshot_download(
- repo_id=cand,
- local_files_only=False,
- allow_patterns=["*.safetensors", "*.bin", "*.json", "*.index.json", "config.json"],
- )
- return path
- except Exception:
- pass
-
- return None
-
-def init_perceiver_from_llm(perceiver, llm, init_from_layers: int | None = None):
- """
- Copies weights from the LLM's first few layers into the PerceiverResampler blocks.
- """
- base = llm.model if hasattr(llm, "model") else llm
- depth = perceiver.depth
- L = min(init_from_layers or depth, depth, len(base.layers))
-
- with torch.no_grad():
- for i in range(L):
- src = base.layers[i] # Qwen2DecoderLayer
- dst = perceiver.blocks[i]
-
- # norms
- dst["input_ln"].weight.copy_(src.input_layernorm.weight)
- dst["post_ln"].weight.copy_(src.post_attention_layernorm.weight)
-
- # attention projections
- dst["cross_attn"].q_proj.weight.copy_(src.self_attn.q_proj.weight)
- dst["cross_attn"].q_proj.bias.copy_(src.self_attn.q_proj.bias)
- dst["cross_attn"].k_proj.weight.copy_(src.self_attn.k_proj.weight)
- dst["cross_attn"].k_proj.bias.copy_(src.self_attn.k_proj.bias)
- dst["cross_attn"].v_proj.weight.copy_(src.self_attn.v_proj.weight)
- dst["cross_attn"].v_proj.bias.copy_(src.self_attn.v_proj.bias)
- dst["cross_attn"].o_proj.weight.copy_(src.self_attn.o_proj.weight)
-
- # mlp
- dst["mlp"].gate_proj.weight.copy_(src.mlp.gate_proj.weight)
- dst["mlp"].up_proj.weight.copy_(src.mlp.up_proj.weight)
- dst["mlp"].down_proj.weight.copy_(src.mlp.down_proj.weight)
-
-def init_perceiver_from_llm_auto(
- perceiver,
- llm,
- ckpt_hint: str | None = None,
- init_from_layers: int | None = None,
- layer_offset: int = 0,
- allow_download: bool = False,
-):
- """
- Prefer initializing from the raw checkpoint on disk; if not found, fall back to
- in-memory quantization-aware init.
- """
- ckpt_dir = resolve_llm_checkpoint_dir(llm, ckpt_hint=ckpt_hint, allow_download=allow_download)
- if ckpt_dir is not None:
- print(f"[Perceiver init] Using checkpoint dir: {ckpt_dir}")
- return init_perceiver_from_llm_checkpoint(
- perceiver,
- ckpt_dir=ckpt_dir,
- init_from_layers=init_from_layers or perceiver.depth,
- layer_offset=layer_offset,
- )
- warnings.warn(
- "[Perceiver init] Could not resolve a checkpoint directory; falling back to "
- "in-memory quantization-aware initialization from the loaded LLM."
- )
- return init_perceiver_from_llm(perceiver, llm, init_from_layers=init_from_layers)
\ No newline at end of file
diff --git a/code/xtuner/model/qwen3_perceiver_resampler.py b/code/xtuner/model/qwen3_perceiver_resampler.py
deleted file mode 100644
index afa14b663f251277c37a93d82e992afada0e39d9..0000000000000000000000000000000000000000
--- a/code/xtuner/model/qwen3_perceiver_resampler.py
+++ /dev/null
@@ -1,616 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-from typing import Optional, Tuple, Iterable, Dict
-import os, json, warnings
-import torch
-import torch.nn as nn
-from mmengine import print_log
-
-# ==== Qwen3 components ====
-from transformers.models.qwen3.modeling_qwen3 import (
- Qwen3RMSNorm,
- Qwen3MLP,
- eager_attention_forward, # honors _attn_implementation (eager/sdpa/flash)
-)
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-
-# ------------------------------------------------------------------------------
-# Helpers
-# ------------------------------------------------------------------------------
-def sinusoidal_positions(L: int, D: int, device=None, base: float = 10000.0):
- """Standard Transformer sinusoidal absolute positions: (1, L, D)."""
- position = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(1) # (L,1)
- div_term = torch.exp(torch.arange(0, D, 2, device=device, dtype=torch.float32) *
- (-math.log(base) / max(1, D // 2)))
- pe = torch.zeros(L, D, dtype=torch.float32, device=device)
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
- return pe.unsqueeze(0) # (1, L, D)
-
-
-# ------------------------------------------------------------------------------
-# Cross-Attention (Qwen3 style)
-# ------------------------------------------------------------------------------
-class Qwen3CrossAttention(nn.Module):
- """
- Cross-attention aligned with Qwen3Attention's backend interface.
-
- - Queries: concatenated [K learnable latents ⊕ text tokens]
- - Keys/Values: visual tokens
- - Backends: eager / sdpa / flash_attn2 via ALL_ATTENTION_FUNCTIONS
- - Positional handling: none here (caller may add ABS pos to queries)
- - Not causal; no masks are used.
- """
- def __init__(self, config, layer_idx: int):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
-
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_heads = config.num_attention_heads
- self.num_kv_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
-
- self.scaling = self.head_dim ** -0.5
- self.attention_dropout = config.attention_dropout
- self.is_causal = False # cross-attn is not causal
- self.sliding_window = None # never used for cross-attn
-
- bias = getattr(config, "attention_bias", False)
-
- self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias)
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias)
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias)
-
- # Qwen3 applies per-head RMSNorm to q and k (on head_dim)
- self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
- self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
- def forward(
- self,
- query_hidden_states: torch.Tensor, # (B, Q, D) -> [latents ⊕ text]
- key_value_hidden_states: torch.Tensor, # (B, N, D) -> visual tokens
- attention_mask: Optional[torch.Tensor] = None, # ignored
- **kwargs,
- ):
- B, Q, _ = query_hidden_states.shape
- N = key_value_hidden_states.shape[1]
- Hd = self.head_dim
-
- # Projections -> (B, seq, heads, Hd), then Qwen3-style q_norm/k_norm, then (B, heads, seq, Hd)
- q = self.q_proj(query_hidden_states).view(B, Q, self.num_heads, Hd)
- q = self.q_norm(q).transpose(1, 2).contiguous()
-
- k = self.k_proj(key_value_hidden_states).view(B, N, self.num_kv_heads, Hd)
- k = self.k_norm(k).transpose(1, 2).contiguous()
-
- v = self.v_proj(key_value_hidden_states).view(B, N, self.num_kv_heads, Hd).transpose(1, 2).contiguous()
-
- attention_interface = eager_attention_forward
- if getattr(self.config, "_attn_implementation", "eager") != "eager":
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, _ = attention_interface(
- self,
- q, k, v,
- attention_mask=None, # explicit: no masks for cross-attn
- dropout=0.0 if not self.training else self.attention_dropout,
- scaling=self.scaling,
- sliding_window=self.sliding_window,
- **kwargs,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous().view(B, Q, self.num_heads * Hd)
- return self.o_proj(attn_output)
-
-
-# ------------------------------------------------------------------------------
-# Perceiver Resampler (uses Qwen3 blocks)
-# ------------------------------------------------------------------------------
-class PerceiverResampler(nn.Module):
- """
- Perceiver-style resampler with:
- - NO attention masks / NO RoPE inside cross-attn
- - Single learned ABS positional embedding on queries over [latents ⊕ text]
- (optionally only [latents], see `concat_text_to_queries`)
- - Optional gradient checkpointing per block
-
- Args:
- llm: Qwen3 model (or wrapper exposing .model and get_input_embeddings()).
- num_latents: number of learnable latent slots (K).
- depth: number of Perceiver blocks.
- max_text_len: maximum supported text length (T_max).
- concat_text_to_queries: if True, queries are [latents ⊕ text]; if False, queries are [latents] only.
- """
- def __init__(
- self,
- llm,
- num_latents: int = 64,
- depth: int = 2,
- pe_gate_ratio: float = 1.0,
- pe_dropout_ratio: float = 0.1,
- *,
- max_text_len: int = 4096,
- concat_text_to_queries: bool = True, # NEW
- ):
- super().__init__()
- base = llm.model if hasattr(llm, "model") else llm
-
- self.config = base.config
- self.hidden_size = self.config.hidden_size
- self.num_latents = num_latents
- self.depth = depth
- self.max_text_len = max_text_len
-
- # NEW: controls whether text tokens are appended to the query latents
- self.concat_text_to_queries: bool = concat_text_to_queries
-
- # Learnable latent queries (K, D)
- self.latents = nn.Parameter(
- torch.randn(1, num_latents, self.hidden_size) / math.sqrt(self.hidden_size)
- )
-
- self.visual_ln = Qwen3RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps)
-
- # Perceiver blocks (Qwen3 norms/MLP + Qwen3-style cross-attn)
- self.blocks = nn.ModuleList()
- for i in range(depth):
- self.blocks.append(nn.ModuleDict({
- "input_ln": Qwen3RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps),
- "cross_attn": Qwen3CrossAttention(self.config, layer_idx=i),
- "post_ln": Qwen3RMSNorm(self.hidden_size, eps=self.config.rms_norm_eps),
- "mlp": Qwen3MLP(self.config),
- }))
-
- # Learned ABS positional embedding over queries [0 .. K + T_max - 1]
- self.query_pos = nn.Embedding(self.num_latents + self.max_text_len, self.hidden_size)
- nn.init.normal_(self.query_pos.weight, mean=0.0, std=0.02)
-
- self.pe_gate = nn.Parameter(torch.tensor(pe_gate_ratio, dtype=self.query_pos.weight.dtype))
- self.pe_dropout = nn.Dropout(pe_dropout_ratio)
- self.resid_scale = 1.0 / math.sqrt(2.0)
-
- # Optional (handy if you embed ids upstream)
- self.text_embed = llm.get_input_embeddings()
-
- # ---- Gradient checkpointing controls (off by default) ----
- self.gradient_checkpointing: bool = False
- self.gc_use_reentrant: bool = True
- self.gc_preserve_rng_state: bool = True
-
- # Public helpers to toggle checkpointing
- def enable_gradient_checkpointing(self, *, use_reentrant: bool = True, preserve_rng_state: bool = True):
- self.gradient_checkpointing = True
- self.gc_use_reentrant = use_reentrant
- self.gc_preserve_rng_state = preserve_rng_state
-
- def disable_gradient_checkpointing(self):
- self.gradient_checkpointing = False
-
- def enable_input_require_grads(self):
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
- self.register_forward_hook(make_inputs_require_grad)
-
- # NEW: runtime toggle (mirrors the other implementation)
- def set_concat_text_queries(self, enabled: bool):
- """If False, only the learnable latents are used as queries (no text concatenation)."""
- self.concat_text_to_queries = enabled
-
- def forward(
- self,
- text_embeddings: torch.FloatTensor, # (B, T, D)
- visual_tokens: torch.FloatTensor, # (B, N, D)
- attention_mask: Optional[torch.Tensor] = None, # ignored
- visual_mask: Optional[torch.Tensor] = None, # ignored
- ) -> torch.Tensor:
- device = self.latents.device
- B, T, D = text_embeddings.shape
- assert D == self.hidden_size, f"text hidden {D} != {self.hidden_size}"
- if T > self.max_text_len:
- raise ValueError(
- f"text length {T} exceeds max_text_len={self.max_text_len}; "
- f"increase max_text_len when constructing PerceiverResampler."
- )
-
- # Queries: either [latents ⊕ text] (default) or [latents] only
- K = self.num_latents
- Q_lat = self.latents.expand(B, -1, -1) # (B, K, D)
- if self.concat_text_to_queries:
- x = torch.cat([Q_lat, text_embeddings], dim=1) # (B, K+T, D)
- else:
- x = Q_lat # (B, K, D)
- Q = x.size(1)
-
- # Learned absolute positions across queries (consume first Q positions)
- pos_ids = torch.arange(Q, device=device).unsqueeze(0).expand(B, -1) # (B, Q)
- x = x + self.pe_dropout(self.query_pos(pos_ids) * self.pe_gate)
-
- visual_tokens = self.visual_ln(visual_tokens)
-
- # Per-block forward (optionally checkpointed)
- for blk_idx, blk in enumerate(self.blocks):
- def _block_fn(t_x, t_v):
- r = t_x
- xn = blk["input_ln"](t_x)
- a = blk["cross_attn"](xn, t_v, attention_mask=None)
- t_x = r + a * self.resid_scale
- r = t_x
- xn = blk["post_ln"](t_x)
- return r + blk["mlp"](xn) * self.resid_scale
-
- if self.gradient_checkpointing and self.training:
- x = torch.utils.checkpoint.checkpoint(
- _block_fn,
- x, visual_tokens,
- use_reentrant=self.gc_use_reentrant,
- preserve_rng_state=self.gc_preserve_rng_state,
- )
- else:
- x = _block_fn(x, visual_tokens)
-
- # Return only the latent slots
- return x[:, :self.num_latents, :]
-# ------------------------------------------------------------------------------
-# Checkpoint utilities
-# ------------------------------------------------------------------------------
-try:
- from safetensors.torch import load_file as safe_load_file
- _HAS_SAFE = True
-except Exception:
- _HAS_SAFE = False
-
-
-def _find_weight_index(ckpt_dir: str) -> Optional[str]:
- for c in ["model.safetensors.index.json", "pytorch_model.bin.index.json"]:
- p = os.path.join(ckpt_dir, c)
- if os.path.isfile(p):
- return p
- return None
-
-
-def _list_all_weight_files(ckpt_dir: str) -> Iterable[str]:
- for name in os.listdir(ckpt_dir):
- if name.endswith(".safetensors") or name.endswith(".bin"):
- if "optimizer" in name or "trainer" in name or name.endswith(".index.json"):
- continue
- yield os.path.join(ckpt_dir, name)
-
-
-def _load_shard(shard_path: str) -> Dict[str, torch.Tensor]:
- if shard_path.endswith(".safetensors"):
- if not _HAS_SAFE:
- raise RuntimeError("safetensors not available; install safetensors or provide .bin weights.")
- return safe_load_file(shard_path, device="cpu")
- return torch.load(shard_path, map_location="cpu")
-
-
-def _gather_needed_tensors_from_checkpoint(ckpt_dir: str, needed_keys: Iterable[str]) -> Dict[str, torch.Tensor]:
- needed = set(needed_keys)
- out: Dict[str, torch.Tensor] = {}
-
- index_path = _find_weight_index(ckpt_dir)
- if index_path is None:
- for fpath in _list_all_weight_files(ckpt_dir):
- shard = _load_shard(fpath)
- for k in list(needed):
- if k in shard:
- out[k] = shard[k]
- needed.remove(k)
- del shard
- if not needed:
- break
- else:
- with open(index_path, "r") as f:
- idx = json.load(f)
- weight_map = idx.get("weight_map") or {}
- shard_to_keys: Dict[str, list] = {}
- for k in needed:
- sn = weight_map.get(k)
- if sn is None:
- continue
- shard_to_keys.setdefault(sn, []).append(k)
- for sn, keys in shard_to_keys.items():
- shard_path = os.path.join(ckpt_dir, sn)
- shard = _load_shard(shard_path)
- for k in keys:
- if k in shard:
- out[k] = shard[k]
- del shard
- needed = {k for k in needed if k not in out}
-
- if needed:
- missing_sorted = "\n - " + "\n - ".join(sorted(needed))
- raise KeyError(f"Missing keys in checkpoint for Perceiver init:{missing_sorted}")
- return out
-
-
-def _copy_param_like(dst_param: torch.nn.Parameter, src_tensor: torch.Tensor):
- print_log(f'Copying param {dst_param.shape} <- {src_tensor.shape}', logger='current')
- dst_param.data.copy_(src_tensor.to(dtype=dst_param.dtype, device=dst_param.device))
-
-
-def _safe_copy_linear_from_tensor(dst_linear: torch.nn.Linear, w: torch.Tensor, b: Optional[torch.Tensor]):
- if dst_linear.weight.shape != w.shape:
- raise RuntimeError(
- f"Shape mismatch copying linear: dst {tuple(dst_linear.weight.shape)} vs src {tuple(w.shape)}"
- )
- print_log(
- f'Copying linear {dst_linear.weight.shape}, bias={dst_linear.bias is not None}',
- logger='current'
- )
- dst_linear.weight.data.copy_(w.to(dtype=dst_linear.weight.dtype, device=dst_linear.weight.device))
- if dst_linear.bias is not None:
- if b is not None:
- dst_linear.bias.data.copy_(b.to(dtype=dst_linear.bias.dtype, device=dst_linear.bias.device))
- else:
- dst_linear.bias.data.zero_()
-
-
-def init_perceiver_from_llm_checkpoint(
- perceiver,
- ckpt_dir: str,
- init_from_layers: Optional[int] = None,
- layer_offset: int = 0,
-):
- """
- Initialize PerceiverResampler from the raw LLM checkpoint files on disk (Qwen3 layout).
- Copies: input/post norms, q/k/v/o (plus q_norm/k_norm weights), MLP gate/up/down for the first `L` layers.
- """
- base_depth = perceiver.depth
- L = min(init_from_layers or base_depth, base_depth)
-
- needed = []
- for i in range(L):
- li = i + layer_offset
- prefix = f"model.layers.{li}"
- # norms
- needed += [
- f"{prefix}.input_layernorm.weight",
- f"{prefix}.post_attention_layernorm.weight",
- # q/k per-head norms (new in Qwen3 attention)
- f"{prefix}.self_attn.q_norm.weight",
- f"{prefix}.self_attn.k_norm.weight",
- ]
- # attention
- needed += [
- f"{prefix}.self_attn.q_proj.weight",
- f"{prefix}.self_attn.k_proj.weight",
- f"{prefix}.self_attn.v_proj.weight",
- f"{prefix}.self_attn.o_proj.weight",
- ]
- # mlp
- needed += [
- f"{prefix}.mlp.gate_proj.weight",
- f"{prefix}.mlp.up_proj.weight",
- f"{prefix}.mlp.down_proj.weight",
- ]
- # optional biases (Qwen3 defaults False, but tolerate presence)
- needed += [
- f"{prefix}.self_attn.q_proj.bias",
- f"{prefix}.self_attn.k_proj.bias",
- f"{prefix}.self_attn.v_proj.bias",
- f"{prefix}.self_attn.o_proj.bias",
- f"{prefix}.mlp.gate_proj.bias",
- f"{prefix}.mlp.up_proj.bias",
- f"{prefix}.mlp.down_proj.bias",
- ]
-
- try:
- tensors = _gather_needed_tensors_from_checkpoint(ckpt_dir, [k for k in needed if "bias" not in k])
- except KeyError as e:
- msg = str(e)
- if "W_pack" in msg or "qkv" in msg:
- msg += "\nDetected fused QKV in checkpoint. This loader expects separate q_proj/k_proj/v_proj."
- raise
-
- # Biases: best-effort
- bias_tensors: Dict[str, torch.Tensor] = {}
- idx_path = _find_weight_index(ckpt_dir)
- if idx_path is not None:
- with open(idx_path, "r") as f:
- idx = json.load(f)
- wmap = idx.get("weight_map") or {}
- shard_to_biases: Dict[str, list] = {}
- for k in needed:
- if "bias" not in k:
- continue
- sn = wmap.get(k)
- if sn is None:
- continue
- shard_to_biases.setdefault(sn, []).append(k)
- for sn, keys in shard_to_biases.items():
- shard_path = os.path.join(ckpt_dir, sn)
- shard = _load_shard(shard_path)
- for k in keys:
- if k in shard:
- bias_tensors[k] = shard[k]
- del shard
- else:
- for fpath in _list_all_weight_files(ckpt_dir):
- shard = _load_shard(fpath)
- for k in [k for k in needed if "bias" in k]:
- if k in shard:
- bias_tensors[k] = shard[k]
- del shard
-
- # Copy into perceiver blocks
- with torch.no_grad():
- for i in range(L):
- li = i + layer_offset
- prefix = f"model.layers.{li}"
- dst = perceiver.blocks[i]
-
- # norms
- _copy_param_like(dst["input_ln"].weight, tensors[f"{prefix}.input_layernorm.weight"])
- _copy_param_like(dst["post_ln"].weight, tensors[f"{prefix}.post_attention_layernorm.weight"])
-
- # per-head norms for q/k
- if f"{prefix}.self_attn.q_norm.weight" in tensors:
- _copy_param_like(dst["cross_attn"].q_norm.weight, tensors[f"{prefix}.self_attn.q_norm.weight"])
- if f"{prefix}.self_attn.k_norm.weight" in tensors:
- _copy_param_like(dst["cross_attn"].k_norm.weight, tensors[f"{prefix}.self_attn.k_norm.weight"])
-
- # attention projections
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].q_proj,
- tensors[f"{prefix}.self_attn.q_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.q_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].k_proj,
- tensors[f"{prefix}.self_attn.k_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.k_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].v_proj,
- tensors[f"{prefix}.self_attn.v_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.v_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["cross_attn"].o_proj,
- tensors[f"{prefix}.self_attn.o_proj.weight"],
- bias_tensors.get(f"{prefix}.self_attn.o_proj.bias"),
- )
-
- # mlp
- _safe_copy_linear_from_tensor(
- dst["mlp"].gate_proj,
- tensors[f"{prefix}.mlp.gate_proj.weight"],
- bias_tensors.get(f"{prefix}.mlp.gate_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["mlp"].up_proj,
- tensors[f"{prefix}.mlp.up_proj.weight"],
- bias_tensors.get(f"{prefix}.mlp.up_proj.bias"),
- )
- _safe_copy_linear_from_tensor(
- dst["mlp"].down_proj,
- tensors[f"{prefix}.mlp.down_proj.weight"],
- bias_tensors.get(f"{prefix}.mlp.down_proj.bias"),
- )
-
-
-def resolve_llm_checkpoint_dir(llm, ckpt_hint: str | None = None, allow_download: bool = False) -> str | None:
- """
- Try to find a local directory for the LLM checkpoint.
- (same logic as before; unchanged)
- """
- if ckpt_hint and os.path.isdir(ckpt_hint):
- return os.path.abspath(ckpt_hint)
-
- candidates = []
- if getattr(getattr(llm, "config", None), "_name_or_path", None):
- candidates.append(llm.config._name_or_path)
- if getattr(llm, "name_or_path", None):
- candidates.append(llm.name_or_path)
-
- for cand in candidates:
- if isinstance(cand, str) and os.path.isdir(cand):
- return os.path.abspath(cand)
-
- try:
- from huggingface_hub import snapshot_download
- except Exception:
- snapshot_download = None
-
- for cand in candidates:
- if not isinstance(cand, str):
- continue
- looks_like_repo_id = "/" in cand and not os.path.isabs(cand)
- if snapshot_download is None or not looks_like_repo_id:
- continue
- try:
- path = snapshot_download(
- repo_id=cand,
- local_files_only=True,
- allow_patterns=["*.safetensors", "*.bin", "*.json", "*.index.json", "config.json"],
- )
- return path
- except Exception:
- if allow_download:
- try:
- path = snapshot_download(
- repo_id=cand,
- local_files_only=False,
- allow_patterns=["*.safetensors", "*.bin", "*.json", "*.index.json", "config.json"],
- )
- return path
- except Exception:
- pass
- return None
-
-
-def init_perceiver_from_llm(perceiver, llm, init_from_layers: int | None = None):
- """
- Copies weights from the LLM's first few layers into the PerceiverResampler blocks (in-memory).
- """
- base = llm.model if hasattr(llm, "model") else llm
- depth = perceiver.depth
- L = min(init_from_layers or depth, depth, len(base.layers))
-
- with torch.no_grad():
- for i in range(L):
- src = base.layers[i] # Qwen3DecoderLayer
- dst = perceiver.blocks[i]
-
- # norms
- dst["input_ln"].weight.copy_(src.input_layernorm.weight)
- dst["post_ln"].weight.copy_(src.post_attention_layernorm.weight)
-
- # per-head q/k norms
- if hasattr(src.self_attn, "q_norm"):
- dst["cross_attn"].q_norm.weight.copy_(src.self_attn.q_norm.weight)
- if hasattr(src.self_attn, "k_norm"):
- dst["cross_attn"].k_norm.weight.copy_(src.self_attn.k_norm.weight)
-
- # attention projections
- dst["cross_attn"].q_proj.weight.copy_(src.self_attn.q_proj.weight)
- if dst["cross_attn"].q_proj.bias is not None and src.self_attn.q_proj.bias is not None:
- dst["cross_attn"].q_proj.bias.copy_(src.self_attn.q_proj.bias)
-
- dst["cross_attn"].k_proj.weight.copy_(src.self_attn.k_proj.weight)
- if dst["cross_attn"].k_proj.bias is not None and src.self_attn.k_proj.bias is not None:
- dst["cross_attn"].k_proj.bias.copy_(src.self_attn.k_proj.bias)
-
- dst["cross_attn"].v_proj.weight.copy_(src.self_attn.v_proj.weight)
- if dst["cross_attn"].v_proj.bias is not None and src.self_attn.v_proj.bias is not None:
- dst["cross_attn"].v_proj.bias.copy_(src.self_attn.v_proj.bias)
-
- dst["cross_attn"].o_proj.weight.copy_(src.self_attn.o_proj.weight)
- if dst["cross_attn"].o_proj.bias is not None and src.self_attn.o_proj.bias is not None:
- dst["cross_attn"].o_proj.bias.copy_(src.self_attn.o_proj.bias)
-
- # mlp
- dst["mlp"].gate_proj.weight.copy_(src.mlp.gate_proj.weight)
- dst["mlp"].up_proj.weight.copy_(src.mlp.up_proj.weight)
- dst["mlp"].down_proj.weight.copy_(src.mlp.down_proj.weight)
-
-
-def init_perceiver_from_llm_auto(
- perceiver,
- llm,
- ckpt_hint: str | None = None,
- init_from_layers: int | None = None,
- layer_offset: int = 0,
- allow_download: bool = False,
-):
- """
- Prefer initializing from the raw checkpoint on disk; if not found, fall back to
- in-memory init from the loaded LLM.
- """
- ckpt_dir = resolve_llm_checkpoint_dir(llm, ckpt_hint=ckpt_hint, allow_download=allow_download)
- if ckpt_dir is not None:
- print(f"[Perceiver init] Using checkpoint dir: {ckpt_dir}")
- return init_perceiver_from_llm_checkpoint(
- perceiver,
- ckpt_dir=ckpt_dir,
- init_from_layers=init_from_layers or perceiver.depth,
- layer_offset=layer_offset,
- )
- warnings.warn(
- "[Perceiver init] Could not resolve a checkpoint directory; falling back to "
- "in-memory initialization from the loaded LLM."
- )
- return init_perceiver_from_llm(perceiver, llm, init_from_layers=init_from_layers)
\ No newline at end of file
diff --git a/code/xtuner/model/reward.py b/code/xtuner/model/reward.py
deleted file mode 100644
index 6bc203daa8ceb5d15be11ed6a37aa9676aa6d32d..0000000000000000000000000000000000000000
--- a/code/xtuner/model/reward.py
+++ /dev/null
@@ -1,490 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import json
-import math
-import os
-import warnings
-from collections import OrderedDict
-from contextlib import nullcontext
-
-import torch
-import torch.distributed as dist
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from mmengine.runner import load_checkpoint
-from peft import get_peft_model, prepare_model_for_kbit_training
-from torch import nn
-from transformers import (AutoConfig, AutoModelForSequenceClassification,
- PreTrainedModel, PreTrainedTokenizer)
-from transformers.dynamic_module_utils import get_class_from_dynamic_module
-from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import no_init_weights
-
-from xtuner.parallel.sequence import (gather_forward_split_backward,
- get_sequence_parallel_group,
- get_sequence_parallel_world_size,
- split_for_sequence_parallel)
-from xtuner.registry import BUILDER
-from .modules import dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, make_inputs_require_grad,
- traverse_dict)
-
-
-def reduce_mean(tensor):
- """"Obtain the mean of tensor on different GPUs."""
- if not (dist.is_available() and dist.is_initialized()):
- return tensor
- tensor = tensor.clone()
- dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
- return tensor
-
-
-def smart_tokenizer_and_embedding_resize(
- tokenizer: PreTrainedTokenizer,
- model: PreTrainedModel,
-):
- """Resize embedding."""
- if is_deepspeed_zero3_enabled():
- import deepspeed
-
- params = [model.get_input_embeddings().weight]
- if model.get_output_embeddings(
- ) is not None and not model.config.tie_word_embeddings:
- params.append(model.get_output_embeddings().weight)
-
- context_maybe_zero3 = deepspeed.zero.GatheredParameters(
- params, modifier_rank=0)
- else:
- context_maybe_zero3 = nullcontext()
-
- with context_maybe_zero3:
- current_embedding_size = model.get_input_embeddings().weight.size(0)
-
- if len(tokenizer) > current_embedding_size:
- assert isinstance(model.get_output_embeddings(), nn.Linear)
-
- model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
- with context_maybe_zero3:
- num_new_tokens = len(tokenizer) - current_embedding_size
- input_embeddings = model.get_input_embeddings().weight.data
- output_embeddings = model.get_output_embeddings().weight.data
-
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True)
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True)
-
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
- print_log(
- f'Resized token embeddings from {current_embedding_size} to '
- f'{len(tokenizer)}.', 'current')
-
-
-class RewardModel(BaseModel):
-
- def __init__(
- self,
- llm,
- lora=None,
- peft_model=None,
- use_activation_checkpointing=True,
- use_varlen_attn=False,
- tokenizer=None,
- max_position_embeddings=None,
- reward_token_id=None,
- loss_type='ranking',
- penalty_type='log_barrier',
- penalty_weight=0.01,
- ):
- super().__init__()
- with LoadWoInit():
- if isinstance(llm, dict):
- llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
- self.llm = self._build_from_cfg_or_module(llm).model
- self.v_head = nn.Linear(self.llm.config.hidden_size, 1, bias=False)
- # zero init
- self.v_head.weight.data.zero_()
-
- self.reward_token_id = reward_token_id
- assert loss_type in ('ranking',
- 'focal'), f'Unsupported loss type {loss_type}'
- self.loss_type = loss_type
- assert penalty_type in (
- 'log_barrier', 'L2',
- 'none'), f'Unsupported penalty type {penalty_type}'
- self.penalty_type = penalty_type
- self.penalty_weight = penalty_weight
-
- if tokenizer is not None:
- if isinstance(tokenizer, dict):
- tokenizer = BUILDER.build(tokenizer)
- smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
-
- self.llm.config.use_cache = False
- dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
-
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
- # enable gradient checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
- lora, ConfigDict):
- self.lora = BUILDER.build(lora)
- else:
- self.lora = lora
- self.peft_model = peft_model
- self.use_lora = lora is not None
- if self.use_lora:
- self._prepare_for_lora(peft_model, use_activation_checkpointing)
-
- self._is_init = True
- # Determines whether to calculate attention based on the
- # seq_len dimension (use_varlen_attn = False) or the actual length of
- # the sequence.
- self.use_varlen_attn = use_varlen_attn
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
-
- def _prepare_for_lora(self,
- peft_model=None,
- use_activation_checkpointing=True):
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if self.lora.target_modules is None:
- modules = find_all_linear_names(self.llm)
- self.lora.target_modules = modules
-
- self.llm = get_peft_model(self.llm, self.lora)
- if peft_model is not None:
- _ = load_checkpoint(self, peft_model)
-
- def init_weights(self):
- pass
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
- if not hasattr(llm_cfg, 'rope_scaling'):
- print_log('Current model does not support RoPE scaling.',
- 'current')
- return
-
- current_max_length = getattr(llm_cfg, 'max_position_embeddings', None)
- if current_max_length and max_position_embeddings > current_max_length:
- print_log(
- f'Enlarge max model length from {current_max_length} '
- f'to {max_position_embeddings}.', 'current')
- scaling_factor = float(
- math.ceil(max_position_embeddings / current_max_length))
- else:
- print_log(
- 'The input `max_position_embeddings` is smaller than '
- 'origin max length. Consider increase input length.',
- 'current')
- scaling_factor = 1.0
- cfg.rope_scaling = {'type': 'linear', 'factor': scaling_factor}
-
- return cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
- labels = data.pop('labels', None)
- if mode == 'loss':
- return self.compute_loss(data, labels)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
- hidden_states = self.llm(**data)[0]
- logits = self.v_head(hidden_states)
- return logits
-
- def predict(self, data, data_samples=None):
- hidden_states = self.llm(**data)[0]
- logits = self.v_head(hidden_states)
- logits_dict = [{'logits': log} for log in logits]
- return logits_dict
-
- @staticmethod
- def _split_for_sequence_parallel(data):
- # attention mask should not be split
- ARGS_NEED_TO_SPLIT = ('input_ids', 'position_ids')
- sp_group = get_sequence_parallel_group()
- for key in ARGS_NEED_TO_SPLIT:
- val = data.get(key, None)
- if val is not None:
- # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
- data[key] = split_for_sequence_parallel(
- val, dim=1, sp_group=sp_group)
- return data
-
- def compute_loss(self, data, labels=None):
- if get_sequence_parallel_world_size() > 1:
- data = self._split_for_sequence_parallel(data)
-
- hidden_states = self.llm(**data)[0]
- logits = self.v_head(hidden_states)
-
- if get_sequence_parallel_world_size() > 1:
- logits = gather_forward_split_backward(
- logits,
- dim=1,
- sp_group=get_sequence_parallel_group(),
- grad_scale='up')
-
- chosen_idx = torch.where(labels == 0)
- rejected_idx = torch.where(labels == 1)
- chosen_logits = logits[chosen_idx]
- rejected_logits = logits[rejected_idx]
-
- num_samples = torch.tensor(len(chosen_logits)).float().to(
- hidden_states.device)
- avg_factor = 1.0 / num_samples
- avg_factor = reduce_mean(avg_factor).to(hidden_states.device)
-
- chosen_mean = reduce_mean(chosen_logits.mean().detach())
- rejected_mean = reduce_mean(rejected_logits.mean().detach())
- acc = reduce_mean(
- (chosen_logits > rejected_logits).sum() / num_samples).detach()
- num_tokens = torch.tensor(labels.shape[1]).float()
-
- # ranking loss
- if self.loss_type == 'ranking':
- rank_loss = self.ranking_loss(
- chosen_logits, rejected_logits, avg_factor=avg_factor)
- elif self.loss_type == 'focal':
- rank_loss = self.focal_loss(
- chosen_logits, rejected_logits, avg_factor=avg_factor)
- else:
- raise NotImplementedError(
- f'Unsupported loss type {self.loss_type}')
-
- # penalty loss
- if self.penalty_type == 'log_barrier':
- penalty = self.log_barrier_penalty(
- torch.cat([chosen_logits, rejected_logits]),
- lower_bound=-5,
- upper_bound=5,
- avg_factor=avg_factor)
- elif self.penalty_type == 'L2':
- penalty = self.l2_penalty(
- torch.cat([chosen_logits, rejected_logits]),
- avg_factor=avg_factor)
- elif self.penalty_type == 'none':
- penalty = 0
- else:
- raise NotImplementedError(
- f'Unsupported penalty type {self.penalty_type}')
-
- loss = rank_loss + self.penalty_weight * penalty
- loss_dict = {
- 'loss': loss,
- 'acc': acc,
- 'chosen_score_mean': chosen_mean,
- 'rejected_score_mean': rejected_mean,
- 'num_samples': num_samples,
- 'num_tokens': num_tokens,
- }
-
- return loss_dict
-
- def ranking_loss(self, chosen_logits, rejected_logits, avg_factor):
- rank_loss = -nn.functional.logsigmoid(chosen_logits - rejected_logits)
- return rank_loss.sum() * avg_factor
-
- def focal_loss(self, chosen_logits, rejected_logits, avg_factor):
- # focal ranking loss from InternLM2 paper https://arxiv.org/abs/2403.17297 # noqa
- rank_loss = -nn.functional.logsigmoid(chosen_logits - rejected_logits)
- p_ij = torch.sigmoid(chosen_logits - rejected_logits)
- p = 2 * torch.relu(p_ij - 0.5)
- gamma = 2
- focal_loss = ((1 - p)**gamma) * rank_loss
- return focal_loss.sum() * avg_factor
-
- def log_barrier_penalty(self,
- logits,
- lower_bound,
- upper_bound,
- epsilon=1e-3,
- avg_factor=1):
- # log barrier penalty from InternLM2 paper https://arxiv.org/abs/2403.17297 # noqa
- logits_fp32 = logits.float()
- logits_clamped = torch.clamp(logits_fp32, lower_bound + epsilon,
- upper_bound - epsilon)
- penalty = -torch.log(upper_bound - logits_clamped) - torch.log(
- logits_clamped - lower_bound)
- return penalty.sum() * avg_factor
-
- def l2_penalty(self, logits, avg_factor=1):
- return (logits**2).sum() * avg_factor
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- if not self.use_lora:
- return state_dict
- to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
- return OrderedDict(to_return)
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- **kwargs):
- print(f'Saving LLM tokenizer to {save_dir}')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(save_dir)
-
- if 'PeftModel' in self.llm.__class__.__name__:
- # merge adapter
- self.llm = self.llm.merge_and_unload()
- if 'InternLM2' in self.llm.__class__.__name__:
- from xtuner.tools.model_converters.modeling_internlm2_reward.modeling_internlm2 import \
- InternLM2ForRewardModel # noqa
- print(f'Saving Reward Model to {save_dir}')
- hf_cfg = self.llm.config
- hf_cfg.reward_token_id = self.reward_token_id if \
- self.reward_token_id is not None else cfg.reward_token_id
- if not fp32:
- dtype = torch.float16
- else:
- dtype = torch.float32
- with no_init_weights():
- reward_model = InternLM2ForRewardModel._from_config(
- hf_cfg, torch_dtype=dtype)
- reward_model.model.load_state_dict(self.llm.state_dict())
- reward_model.v_head.load_state_dict(self.v_head.state_dict())
- reward_model.save_pretrained(save_dir, **save_pretrained_kwargs)
- # fix auto_map in config
- with open(os.path.join(save_dir, 'config.json')) as fp:
- config_dict = json.load(fp)
- config_dict['auto_map'][
- 'AutoModel'] = 'modeling_internlm2.InternLM2ForRewardModel'
- config_dict['auto_map'].pop('AutoModelForCausalLM', None)
- with open(os.path.join(save_dir, 'config.json'), 'w') as fp:
- json.dump(config_dict, fp, indent=2)
- else:
- warnings.warn(
- f'The pretrained model type: {self.llm.__class__.__name__} '
- 'has no reward model class defined. Use '
- 'the SequenceClassification class instead.'
- 'You can refer to `xtuner/tools/model_converters/modeling_internlm2_reward` ' # noqa
- 'to implement the reward model class.')
-
- hf_cfg = self.llm.config
- hf_cfg.num_labels = 1 # set the output dim to 1
- try:
- with no_init_weights():
- reward_model = \
- AutoModelForSequenceClassification.from_config(hf_cfg)
- except Exception as e:
- warnings.warn(f'Cannot find SequenceClassification class '
- f'from transformers: {e}, \n'
- 'try to find it in the dynamic module.')
- module_file, causal_model_name = hf_cfg.auto_map[
- 'AutoModelForCausalLM'].split('.')
- seqcls_model_name = causal_model_name.split(
- 'For')[0] + 'ForSequenceClassification'
- seqcls_class = get_class_from_dynamic_module(
- f'{module_file}.{seqcls_model_name}', hf_cfg._name_or_path)
- with no_init_weights():
- reward_model = seqcls_class(hf_cfg)
- reward_model.model.load_state_dict(self.llm.state_dict())
- reward_model.score.load_state_dict(self.v_head.state_dict())
- reward_model.save_pretrained(save_dir, **save_pretrained_kwargs)
diff --git a/code/xtuner/model/sft.py b/code/xtuner/model/sft.py
deleted file mode 100644
index 5229504891b3d921286ef0106c84ebd1349e378e..0000000000000000000000000000000000000000
--- a/code/xtuner/model/sft.py
+++ /dev/null
@@ -1,336 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-from collections import OrderedDict
-from contextlib import nullcontext
-
-import torch
-from mmengine import print_log
-from mmengine.config import Config, ConfigDict
-from mmengine.model import BaseModel
-from mmengine.runner import load_checkpoint
-from peft import get_peft_model, prepare_model_for_kbit_training
-from torch import nn
-from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer
-from transformers.integrations import is_deepspeed_zero3_enabled
-
-from xtuner.parallel.sequence import (get_sequence_parallel_group,
- get_sequence_parallel_world_size,
- reduce_sequence_parallel_loss,
- split_for_sequence_parallel)
-from xtuner.registry import BUILDER
-from .modules import dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
-from .utils import (LoadWoInit, find_all_linear_names,
- get_peft_model_state_dict, make_inputs_require_grad,
- traverse_dict)
-
-
-def smart_tokenizer_and_embedding_resize(
- tokenizer: PreTrainedTokenizer,
- model: PreTrainedModel,
-):
- """Resize embedding."""
- if is_deepspeed_zero3_enabled():
- import deepspeed
-
- params = [model.get_input_embeddings().weight]
- if model.get_output_embeddings(
- ) is not None and not model.config.tie_word_embeddings:
- params.append(model.get_output_embeddings().weight)
-
- context_maybe_zero3 = deepspeed.zero.GatheredParameters(
- params, modifier_rank=0)
- else:
- context_maybe_zero3 = nullcontext()
-
- with context_maybe_zero3:
- current_embedding_size = model.get_input_embeddings().weight.size(0)
-
- if len(tokenizer) > current_embedding_size:
- assert isinstance(model.get_output_embeddings(), nn.Linear)
-
- model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
- with context_maybe_zero3:
- num_new_tokens = len(tokenizer) - current_embedding_size
- input_embeddings = model.get_input_embeddings().weight.data
- output_embeddings = model.get_output_embeddings().weight.data
-
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True)
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True)
-
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
- print_log(
- f'Resized token embeddings from {current_embedding_size} to '
- f'{len(tokenizer)}.', 'current')
-
-
-class SupervisedFinetune(BaseModel):
-
- def __init__(self,
- llm,
- lora=None,
- peft_model=None,
- use_activation_checkpointing=True,
- use_varlen_attn=False,
- tokenizer=None,
- max_position_embeddings=None):
- super().__init__()
-
- self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
- max_position_embeddings)
-
- if tokenizer is not None:
- if isinstance(tokenizer, dict):
- tokenizer = BUILDER.build(tokenizer)
- smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
-
- self.llm.config.use_cache = False
- if use_activation_checkpointing:
- # For backward compatibility
- if hasattr(self.llm, 'enable_input_require_grads'):
- self.llm.enable_input_require_grads()
- else:
- self.llm.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad)
-
- # enable gradient checkpointing for memory efficiency
- self.gradient_checkpointing_enable()
-
- if isinstance(lora, dict) or isinstance(lora, Config) or isinstance(
- lora, ConfigDict):
- self.lora = BUILDER.build(lora)
- else:
- self.lora = lora
- self.peft_model = peft_model
- self.use_lora = lora is not None
- if self.use_lora:
- self._prepare_for_lora(peft_model, use_activation_checkpointing)
-
- self._is_init = True
- # Determines whether to calculate attention based on the
- # seq_len dimension (use_varlen_attn = False) or the actual length of
- # the sequence.
- self.use_varlen_attn = use_varlen_attn
-
- def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
- max_position_embeddings):
- # For forward
- with LoadWoInit():
- if isinstance(llm_cfg, dict):
- llm = self._dispatch_lm_model_cfg(llm_cfg,
- max_position_embeddings)
- llm = self._build_from_cfg_or_module(llm)
-
- llm.config.use_cache = False
- dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
- return llm
-
- def gradient_checkpointing_enable(self):
- self.activation_checkpointing_enable()
-
- def activation_checkpointing_enable(self):
- self.llm.gradient_checkpointing_enable()
-
- def gradient_checkpointing_disable(self):
- self.activation_checkpointing_disable()
-
- def activation_checkpointing_disable(self):
- self.llm.gradient_checkpointing_disable()
-
- def _prepare_for_lora(self,
- peft_model=None,
- use_activation_checkpointing=True):
- self.llm = prepare_model_for_kbit_training(
- self.llm, use_activation_checkpointing)
- if self.lora.target_modules is None:
- modules = find_all_linear_names(self.llm)
- self.lora.target_modules = modules
-
- self.llm = get_peft_model(self.llm, self.lora)
- if peft_model is not None:
- _ = load_checkpoint(self, peft_model)
-
- def init_weights(self):
- pass
-
- @staticmethod
- def _prepare_for_long_context_training(cfg, llm_cfg,
- max_position_embeddings):
- if not hasattr(llm_cfg, 'rope_scaling'):
- print_log('Current model does not support RoPE scaling.',
- 'current')
- return
-
- current_max_length = getattr(llm_cfg, 'max_position_embeddings', None)
- if current_max_length and max_position_embeddings > current_max_length:
- print_log(
- f'Enlarge max model length from {current_max_length} '
- f'to {max_position_embeddings}.', 'current')
- scaling_factor = float(
- math.ceil(max_position_embeddings / current_max_length))
- else:
- print_log(
- 'The input `max_position_embeddings` is smaller than '
- 'origin max length. Consider increase input length.',
- 'current')
- scaling_factor = 1.0
- cfg.rope_scaling = {'type': 'linear', 'factor': scaling_factor}
-
- return cfg
-
- @staticmethod
- def _prepare_for_flash_attn(cfg, llm_cfg):
- cls_name = type(llm_cfg).__name__
- SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
- 'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
- 'Starcoder2Config', 'Starcoder2Config',
- 'Phi3Config')
- SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
- 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
- 'Qwen2MoeConfig', 'Starcoder2Config',
- 'Starcoder2Config', 'Phi3Config',
- 'DeepseekV2Config')
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- if getattr(cfg, 'attn_implementation', None) is not None:
- # Flash Attention 2.0 only supports torch.float16 and
- # torch.bfloat16 dtypes
- if cfg.attn_implementation == 'flash_attention_2':
- cfg.torch_dtype = torch_dtype
- elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
- cfg.torch_dtype = torch_dtype
- cfg.attn_implementation = 'flash_attention_2'
- elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
- cfg.attn_implementation = 'sdpa'
-
- return cfg
-
- @staticmethod
- def _prepare_for_qlora_zero3(cfg):
- if (not is_deepspeed_zero3_enabled()) or (not hasattr(
- cfg, 'quantization_config')):
- return cfg
-
- torch_dtype = torch.bfloat16 if (
- torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
- else torch.float16
-
- cfg.torch_dtype = torch_dtype
- quantization_config = cfg.quantization_config
- quantization_config.bnb_4bit_compute_dtype = torch_dtype
- quantization_config.bnb_4bit_quant_storage = torch_dtype
-
- return cfg
-
- def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
- cfg = self._prepare_for_qlora_zero3(cfg)
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- llm_cfg = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
- cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
- if max_position_embeddings is not None:
- cfg = self._prepare_for_long_context_training(
- cfg, llm_cfg, max_position_embeddings)
- return cfg
-
- def _build_from_cfg_or_module(self, cfg_or_mod):
- if isinstance(cfg_or_mod, nn.Module):
- return cfg_or_mod
- elif isinstance(cfg_or_mod, dict):
- traverse_dict(cfg_or_mod)
- return BUILDER.build(cfg_or_mod)
- else:
- raise NotImplementedError
-
- def forward(self, data, data_samples=None, mode='loss'):
-
- if mode == 'loss':
- return self.compute_loss(data, data_samples)
- elif mode == 'predict':
- return self.predict(data, data_samples)
- elif mode == 'tensor':
- return self._forward(data, data_samples)
- else:
- raise NotImplementedError
-
- def _forward(self, data, data_samples=None):
-
- outputs = self.llm(**data)
-
- return outputs
-
- def predict(self, data, data_samples=None):
- outputs = self.llm(**data)
- logits_dict = [{'logits': logits} for logits in outputs.logits]
- return logits_dict
-
- @staticmethod
- def _split_for_sequence_parallel(data):
- # attention mask should not be split
- ARGS_NEED_TO_SPLIT = ('input_ids', 'labels', 'position_ids')
- sp_group = get_sequence_parallel_group()
- for key in ARGS_NEED_TO_SPLIT:
- val = data.get(key, None)
- if val is not None:
- # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
- data[key] = split_for_sequence_parallel(
- val, dim=1, sp_group=sp_group)
- return data
-
- def _compute_sequence_parallel_loss(self, data):
- data = self._split_for_sequence_parallel(data)
- outputs = self.llm(**data)
- labels = data['labels']
- num_tokens = (labels != -100).sum()
- sp_group = get_sequence_parallel_group()
- loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens,
- sp_group)
- return {'loss': loss}
-
- def compute_loss(self, data, data_samples=None):
- if get_sequence_parallel_world_size() > 1:
- return self._compute_sequence_parallel_loss(data)
- else:
- outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
-
- def state_dict(self, *args, **kwargs):
- state_dict = super().state_dict(*args, **kwargs)
- if not self.use_lora:
- return state_dict
- to_return = get_peft_model_state_dict(self.llm, state_dict=state_dict)
- return OrderedDict(to_return)
-
- def __getattr__(self, name: str):
- try:
- return super().__getattr__(name)
- except AttributeError:
- return getattr(self.llm, name)
-
- def to_hf(self,
- cfg,
- save_dir,
- fp32=False,
- save_pretrained_kwargs={},
- **kwargs):
- self.llm.config.use_cache = True
- if not fp32:
- print_log('Convert LLM to float16', 'current')
- self.llm.half()
- if self.use_lora:
- print_log(f'Saving adapter to {save_dir}', 'current')
- else:
- print_log(f'Saving LLM tokenizer to {save_dir}', 'current')
- tokenizer = BUILDER.build(cfg.tokenizer)
- tokenizer.save_pretrained(save_dir)
- print_log(f'Saving LLM to {save_dir}', 'current')
- self.llm.save_pretrained(save_dir, **save_pretrained_kwargs)
- self.llm.config.use_cache = False
diff --git a/code/xtuner/model/sparse_token_merge.py b/code/xtuner/model/sparse_token_merge.py
deleted file mode 100644
index 67627f20466dcf4e07aef731df78121c90ad0a69..0000000000000000000000000000000000000000
--- a/code/xtuner/model/sparse_token_merge.py
+++ /dev/null
@@ -1,213 +0,0 @@
-from typing import Optional, Tuple
-import torch
-import torch.nn as nn
-from mmengine import print_log
-
-class SparsePatchMerging(nn.Module):
- """
- Stable 2x2 (stride-2) patch merging for Swin/LongNet with:
- - Pre/Post LayerNorm (eps small) for bf16 stability
- - Deterministic TL/TR/BL/BR concat ordering
- - FP32 reductions for scatter/index_add
- - Variance-preserving "sum" option
- - Robust handling of ragged tokens via coords_rc + padmask
- - Optional clear-on-OOM retry
-
- forward(x: [B,L,C], coords_rc: [B,L,2] or [L,2], padmask: Optional[bool[B,L]])
- -> (x_merged: [B,L_out,C_out], coords_merged: [B,L_out,2], padmask_merged: [B,L_out])
- where C_out = 4*C if mode='concat' else C (no final projection)
- """
- def __init__(
- self,
- embed_dim: int,
- layernorm_eps: float,
- keep_dim: bool = True, # kept for API compatibility; unused now
- merge_size: int = 2,
- out_dim: Optional[int] = None, # kept for API compatibility; ignored
- pre_norm: bool = True,
- post_norm: bool = True,
- fp32_reduce: bool = True,
- mode: str = 'concat', # 'concat' or 'sum'
- clear_after_forward: bool = False,
- clear_on_oom: bool = True,
- **kwargs,
- ) -> None:
- super().__init__()
- assert merge_size == 2, "Only 2x2 merging supported."
- self.embed_dim = embed_dim
- self.merge_size = merge_size
- self.clear_after_forward = clear_after_forward
- self.clear_on_oom = clear_on_oom
- self._retrying = False
-
- # --- No final reduction: output dim is purely based on merge mode ---
- in_linear = embed_dim * 4 if mode == 'concat' else embed_dim
- self.out_dim = in_linear
-
- self.pre_norm = pre_norm
- self.post_norm = post_norm
- self.fp32_reduce = fp32_reduce
- self.mode = mode
-
- eps = layernorm_eps if layernorm_eps is not None else 1e-6
- self.ln_in = nn.LayerNorm(embed_dim, eps=eps) if pre_norm else None
- self.ln_out = nn.LayerNorm(self.out_dim, eps=eps) if post_norm else None
-
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- if self.ln_in is not None:
- self.ln_in.reset_parameters()
- if self.ln_out is not None:
- self.ln_out.reset_parameters()
-
- def _clear_cuda_cache(self):
- try:
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
- except Exception:
- pass
-
- @staticmethod
- def _ensure_batched_coords(coords_rc: torch.Tensor, B: int) -> torch.Tensor:
- if coords_rc.dim() == 2:
- coords_rc = coords_rc.unsqueeze(0).expand(B, -1, -1)
- return coords_rc
-
- def _forward_impl(
- self,
- x: torch.Tensor, # [B, L, C]
- coords_rc: torch.Tensor, # [B, L, 2] or [L, 2]
- padmask: Optional[torch.Tensor] # [B, L] bool (True=pad)
- ):
- B, L, C = x.shape
- coords_rc = self._ensure_batched_coords(coords_rc, B)
- # print(coords_rc.shape, x.shape)
- assert coords_rc.shape[:2] == (B, L)
-
- if self.ln_in is not None:
- x = self.ln_in(x)
-
- x_dtype = x.dtype
- red_dtype = torch.float32 if self.fp32_reduce else x.dtype
-
- out_x_list = []
- out_coords_list = []
-
- key_stride = int(2**20) # enough for typical H,W<1e6
-
- for b in range(B):
- xb = x[b] # [L, C]
- coords_b = coords_rc[b].to(torch.long) # [L, 2]
- valid = torch.ones(L, dtype=torch.bool, device=x.device) if padmask is None else (~padmask[b])
-
- rb = coords_b[:, 0]
- cb = coords_b[:, 1]
- gr = torch.div(rb, self.merge_size, rounding_mode='floor')
- gc = torch.div(cb, self.merge_size, rounding_mode='floor')
- keys = gr * key_stride + gc
-
- # TL=0 (r%2=0,c%2=0), TR=1, BL=2, BR=3 — fixed ordering
- rmod = torch.remainder(rb, self.merge_size)
- cmod = torch.remainder(cb, self.merge_size)
- corner = (rmod << 1) | cmod
-
- sel = valid
- if sel.sum() == 0:
- out_x_list.append(xb.new_zeros(0, self.out_dim))
- out_coords_list.append(coords_b.new_zeros(0, 2))
- continue
-
- keys = keys[sel]
- corner = corner[sel]
- xb_sel = xb[sel]
-
- uniq, inv = torch.unique(keys, sorted=True, return_inverse=True)
- G = uniq.numel()
-
- # recover (gr,gc) from linear keys
- gc_out = torch.remainder(uniq, key_stride)
- gr_out = torch.div(uniq, key_stride, rounding_mode='floor')
- coords_out = torch.stack([gr_out, gc_out], dim=-1) # [G,2]
-
- if self.mode == 'concat':
- # do FP32 accumulation, then cast back; no projection
- out_buf = torch.zeros(G, 4, C, device=x.device, dtype=red_dtype)
- counts = torch.zeros(G, 4, device=x.device, dtype=red_dtype)
-
- for k in range(4):
- mask_k = (corner == k)
- if mask_k.any():
- gi = inv[mask_k] # [Nk]
- xk = xb_sel[mask_k].to(red_dtype) # [Nk, C]
- out_buf[:, k, :].index_add_(0, gi, xk)
- counts[:, k].index_add_(0, gi, torch.ones(gi.shape[0], device=x.device, dtype=red_dtype))
-
- counts_clamped = counts.clamp_min(1.0).unsqueeze(-1) # [G,4,1]
- out_buf = out_buf / counts_clamped
- out_feat = out_buf.reshape(G, 4*C).to(x_dtype) # [G, 4C]
-
- elif self.mode == 'sum':
- # variance-preserving sum; no projection
- out_buf = torch.zeros(G, C, device=x.device, dtype=red_dtype)
- counts = torch.zeros(G, 1, device=x.device, dtype=red_dtype)
- gi = inv
- xk = xb_sel.to(red_dtype)
- out_buf.index_add_(0, gi, xk)
- counts.index_add_(0, gi, torch.ones(gi.shape[0], device=x.device, dtype=red_dtype).unsqueeze(-1))
- scale = counts.clamp_min(1.0).sqrt().reciprocal()
- out_feat = (out_buf * scale).to(x_dtype) # [G, C]
- else:
- raise ValueError(f"Unknown mode {self.mode}")
-
- if self.ln_out is not None:
- out_feat = self.ln_out(out_feat)
-
- out_x_list.append(out_feat) # [G, out_dim]
- out_coords_list.append(coords_out) # [G, 2]
-
- # pack (pad) to max group length across batch for dense return
- Gmax = max((t.shape[0] for t in out_x_list), default=0)
- out_x = x.new_zeros(B, Gmax, self.out_dim)
- out_coords = coords_rc.new_zeros(B, Gmax, 2)
- out_mask = torch.ones(B, Gmax, dtype=torch.bool, device=x.device) # True=pad
-
- for b in range(B):
- G = out_x_list[b].shape[0]
- out_x[b, :G] = out_x_list[b]
- out_coords[b, :G] = out_coords_list[b]
- out_mask[b, :G] = False
-
- return out_x, out_coords, out_mask
-
- def forward(
- self,
- x: torch.Tensor, # [B, L, C]
- coords_rc: torch.Tensor, # [B, L, 2] or [L, 2]
- padmask: Optional[torch.Tensor] = None
- ):
- try:
- out = self._forward_impl(x, coords_rc, padmask)
- if self.clear_after_forward and x.is_cuda:
- self._clear_cuda_cache()
- return out
- except RuntimeError as e:
- msg = str(e).lower()
- is_cuda_oom = ("out of memory" in msg) or ("cuda" in msg and "alloc" in msg) or ("cublas" in msg and "alloc" in msg)
- if self.clear_on_oom and is_cuda_oom and x.is_cuda and not self._retrying:
- self._retrying = True
- self._clear_cuda_cache()
- torch.cuda.reset_peak_memory_stats()
- try:
- return self._forward_impl(x, coords_rc, padmask)
- finally:
- self._retrying = False
- raise
-
- def enable_input_require_grads(self):
- print_log("enable input required grads for patch merging", 'current')
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- self.model.register_forward_hook(make_inputs_require_grad)
diff --git a/code/xtuner/model/torchscale/__init__.py b/code/xtuner/model/torchscale/__init__.py
deleted file mode 100644
index 3ae31e2507e8759f2ac7f85e517288f536c04ac3..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
diff --git a/code/xtuner/model/torchscale/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/torchscale/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index bc13f510f9542792a7ea8b15f9c0db167057feec..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/architecture/__init__.py b/code/xtuner/model/torchscale/architecture/__init__.py
deleted file mode 100644
index 3ae31e2507e8759f2ac7f85e517288f536c04ac3..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
diff --git a/code/xtuner/model/torchscale/architecture/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/torchscale/architecture/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 03cb9f95aaf85785165e4805b52efc26a7978a6a..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/architecture/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/architecture/__pycache__/config.cpython-311.pyc b/code/xtuner/model/torchscale/architecture/__pycache__/config.cpython-311.pyc
deleted file mode 100644
index b573d5da1017fb6707838b086197e9bf4aadad8e..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/architecture/__pycache__/config.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/architecture/__pycache__/decoder.cpython-311.pyc b/code/xtuner/model/torchscale/architecture/__pycache__/decoder.cpython-311.pyc
deleted file mode 100644
index a2deadcd3770c2308a7419e93f0dd906618d11bf..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/architecture/__pycache__/decoder.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/architecture/__pycache__/encoder.cpython-311.pyc b/code/xtuner/model/torchscale/architecture/__pycache__/encoder.cpython-311.pyc
deleted file mode 100644
index 4cbf771a5ee7f8f9d59129e997576738337ea5cc..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/architecture/__pycache__/encoder.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/architecture/__pycache__/utils.cpython-311.pyc b/code/xtuner/model/torchscale/architecture/__pycache__/utils.cpython-311.pyc
deleted file mode 100644
index 786d6d7cb0a03d8e37141bf90b7617527c7028b5..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/architecture/__pycache__/utils.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/architecture/config.py b/code/xtuner/model/torchscale/architecture/config.py
deleted file mode 100644
index 36fe0223e6a2e636afc17e0875d5c37ec839d29f..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/config.py
+++ /dev/null
@@ -1,315 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-
-class EncoderConfig(object):
- def __init__(self, **kwargs):
- self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
- self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
- self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
- self.encoder_layers = kwargs.pop("encoder_layers", 12)
- self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
- self.normalize_output = kwargs.pop("normalize_output", True)
- self.activation_fn = kwargs.pop("activation_fn", "gelu")
- self.dropout = kwargs.pop("dropout", 0.0)
- self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
- self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
- self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
- self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
- self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
- self.moe_freq = kwargs.pop("moe_freq", 0)
- self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
- self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
- self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
- self.moe_eval_capacity_token_fraction = kwargs.pop(
- "moe_eval_capacity_token_fraction", 0.25
- )
- self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
- self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
- "moe_normalize_gate_prob_before_dropping", False
- )
- self.use_xmoe = kwargs.pop("use_xmoe", False)
- self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
- self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
- self.deepnorm = kwargs.pop("deepnorm", False)
- self.subln = kwargs.pop("subln", True)
- self.bert_init = kwargs.pop("bert_init", False)
- self.multiway = kwargs.pop("multiway", False)
- self.share_encoder_input_output_embed = kwargs.pop(
- "share_encoder_input_output_embed", False
- )
- self.max_source_positions = kwargs.pop("max_source_positions", 1024)
- self.no_output_layer = kwargs.pop("no_output_layer", False)
- self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
- # Text
- self.vocab_size = kwargs.pop("vocab_size", -1)
- # Vision
- self.img_size = kwargs.pop("img_size", 224)
- self.patch_size = kwargs.pop("patch_size", 16)
- self.in_chans = kwargs.pop("in_chans", 3)
- # Fairscale
- self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
- self.fsdp = kwargs.pop("fsdp", False)
- self.ddp_rank = kwargs.pop("ddp_rank", 0)
- self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
- self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
- # Dilated Attention
- self.flash_attention = kwargs.pop("flash_attention", False)
- self.segment_length = kwargs.pop("segment_length", None)
- self.dilated_ratio = kwargs.pop("dilated_ratio", None)
- self.seq_parallel = kwargs.pop("seq_parallel", False)
- self.postprocessing()
-
- def override(self, args):
- for hp in self.__dict__.keys():
- if getattr(args, hp, None) is not None:
- self.__dict__[hp] = getattr(args, hp, None)
- self.postprocessing()
-
- def postprocessing(self):
- if self.segment_length is not None and self.segment_length != '':
- self.segment_length = eval(self.segment_length)
- if self.dilated_ratio is not None and self.dilated_ratio != '':
- self.dilated_ratio = eval(self.dilated_ratio)
-
- if self.deepnorm:
- self.encoder_normalize_before = False
- self.subln = False
- if self.subln:
- self.encoder_normalize_before = True
- self.deepnorm = False
- if self.use_xmoe:
- self.moe_normalize_gate_prob_before_dropping = True
- self.moe_second_expert_policy = "random"
- assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-
-class DecoderConfig(object):
- def __init__(self, **kwargs):
- self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
- self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
- self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
- self.decoder_layers = kwargs.pop("decoder_layers", 12)
- self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
- self.activation_fn = kwargs.pop("activation_fn", "gelu")
- self.dropout = kwargs.pop("dropout", 0.0)
- self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
- self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
- self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
- self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
- self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
- self.moe_freq = kwargs.pop("moe_freq", 0)
- self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
- self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
- self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
- self.moe_eval_capacity_token_fraction = kwargs.pop(
- "moe_eval_capacity_token_fraction", 0.25
- )
- self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
- self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
- "moe_normalize_gate_prob_before_dropping", False
- )
- self.use_xmoe = kwargs.pop("use_xmoe", False)
- self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
- self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
- self.deepnorm = kwargs.pop("deepnorm", False)
- self.subln = kwargs.pop("subln", True)
- self.bert_init = kwargs.pop("bert_init", False)
- self.multiway = kwargs.pop("multiway", False)
- self.share_decoder_input_output_embed = kwargs.pop(
- "share_decoder_input_output_embed", False
- )
- self.max_target_positions = kwargs.pop("max_target_positions", 1024)
- self.no_output_layer = kwargs.pop("no_output_layer", False)
- self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
- # Text
- self.vocab_size = kwargs.pop("vocab_size", -1)
- # Fairscale
- self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
- self.fsdp = kwargs.pop("fsdp", False)
- self.ddp_rank = kwargs.pop("ddp_rank", 0)
- self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
- self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
- # Dilated Attention
- self.flash_attention = kwargs.pop("flash_attention", False)
- self.segment_length = kwargs.pop("segment_length", None)
- self.dilated_ratio = kwargs.pop("dilated_ratio", None)
- self.seq_parallel = kwargs.pop("seq_parallel", False)
- self.postprocessing()
-
- def override(self, args):
- for hp in self.__dict__.keys():
- if getattr(args, hp, None) is not None:
- self.__dict__[hp] = getattr(args, hp, None)
- self.postprocessing()
-
- def postprocessing(self):
- if self.segment_length is not None and self.segment_length != '':
- self.segment_length = eval(self.segment_length)
- if self.dilated_ratio is not None and self.dilated_ratio != '':
- self.dilated_ratio = eval(self.dilated_ratio)
-
- if self.deepnorm:
- self.encoder_normalize_before = False
- self.subln = False
- if self.subln:
- self.encoder_normalize_before = True
- self.deepnorm = False
- if self.use_xmoe:
- self.moe_normalize_gate_prob_before_dropping = True
- self.moe_second_expert_policy = "random"
- assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-
-class EncoderDecoderConfig(object):
- def __init__(self, **kwargs):
- self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
- self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
- self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
- self.encoder_layers = kwargs.pop("encoder_layers", 12)
- self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
- self.normalize_output = kwargs.pop("normalize_output", True)
- self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
- self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
- self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
- self.decoder_layers = kwargs.pop("decoder_layers", 12)
- self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
- self.activation_fn = kwargs.pop("activation_fn", "gelu")
- self.dropout = kwargs.pop("dropout", 0.0)
- self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
- self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
- self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
- self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
- self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
- self.moe_freq = kwargs.pop("moe_freq", 0)
- self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
- self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
- self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
- self.moe_eval_capacity_token_fraction = kwargs.pop(
- "moe_eval_capacity_token_fraction", 0.25
- )
- self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
- self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
- "moe_normalize_gate_prob_before_dropping", False
- )
- self.use_xmoe = kwargs.pop("use_xmoe", False)
- self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
- self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
- self.deepnorm = kwargs.pop("deepnorm", False)
- self.subln = kwargs.pop("subln", True)
- self.bert_init = kwargs.pop("bert_init", False)
- self.multiway = kwargs.pop("multiway", False)
- self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
- self.share_decoder_input_output_embed = kwargs.pop(
- "share_decoder_input_output_embed", False
- )
- self.max_source_positions = kwargs.pop("max_source_positions", 1024)
- self.max_target_positions = kwargs.pop("max_target_positions", 1024)
- self.no_output_layer = kwargs.pop("no_output_layer", False)
- self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
- # Text
- self.vocab_size = kwargs.pop("vocab_size", -1)
- # Fairscale
- self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
- self.fsdp = kwargs.pop("fsdp", False)
- self.ddp_rank = kwargs.pop("ddp_rank", 0)
- self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
- self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
- # Dilated Attention
- self.flash_attention = kwargs.pop("flash_attention", False)
- self.segment_length = kwargs.pop("segment_length", None)
- self.dilated_ratio = kwargs.pop("dilated_ratio", None)
- self.seq_parallel = kwargs.pop("seq_parallel", False)
- self.postprocessing()
-
- def override(self, args):
- for hp in self.__dict__.keys():
- if getattr(args, hp, None) is not None:
- self.__dict__[hp] = getattr(args, hp, None)
- self.postprocessing()
-
- def postprocessing(self):
- if self.segment_length is not None and self.segment_length != '':
- self.segment_length = eval(self.segment_length)
- if self.dilated_ratio is not None and self.dilated_ratio != '':
- self.dilated_ratio = eval(self.dilated_ratio)
-
- if self.deepnorm:
- self.encoder_normalize_before = False
- self.subln = False
- if self.subln:
- self.encoder_normalize_before = True
- self.deepnorm = False
- if self.use_xmoe:
- self.moe_normalize_gate_prob_before_dropping = True
- self.moe_second_expert_policy = "random"
- assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-
-class RetNetConfig(object):
- def __init__(self, **kwargs):
- self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
- self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
- self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
- self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
- self.decoder_layers = kwargs.pop("decoder_layers", 12)
- self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
- self.activation_fn = kwargs.pop("activation_fn", "gelu")
- self.dropout = kwargs.pop("dropout", 0.0)
- self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
- self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
- self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
- self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
- self.moe_freq = kwargs.pop("moe_freq", 0)
- self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
- self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
- self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
- self.moe_eval_capacity_token_fraction = kwargs.pop(
- "moe_eval_capacity_token_fraction", 0.25
- )
- self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
- self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
- "moe_normalize_gate_prob_before_dropping", False
- )
- self.use_xmoe = kwargs.pop("use_xmoe", False)
- self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
- self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
- self.deepnorm = kwargs.pop("deepnorm", False)
- self.subln = kwargs.pop("subln", True)
- self.multiway = kwargs.pop("multiway", False)
- self.share_decoder_input_output_embed = kwargs.pop(
- "share_decoder_input_output_embed", False
- )
- self.max_target_positions = kwargs.pop("max_target_positions", 1024)
- self.no_output_layer = kwargs.pop("no_output_layer", False)
- self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
- # Blockwise
- self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
- self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
- # Text
- self.vocab_size = kwargs.pop("vocab_size", -1)
- # Fairscale
- self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
- self.fsdp = kwargs.pop("fsdp", False)
- self.ddp_rank = kwargs.pop("ddp_rank", 0)
- self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
- self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
- self.postprocessing()
-
- def override(self, args):
- for hp in self.__dict__.keys():
- if getattr(args, hp, None) is not None:
- self.__dict__[hp] = getattr(args, hp, None)
- self.postprocessing()
-
- def postprocessing(self):
- if self.deepnorm:
- self.encoder_normalize_before = False
- self.subln = False
- if self.subln:
- self.encoder_normalize_before = True
- self.deepnorm = False
- if self.use_xmoe:
- self.moe_normalize_gate_prob_before_dropping = True
- self.moe_second_expert_policy = "random"
- assert self.moe_freq > 0 and self.moe_expert_count > 0
diff --git a/code/xtuner/model/torchscale/architecture/decoder.py b/code/xtuner/model/torchscale/architecture/decoder.py
deleted file mode 100644
index 4006b0cf4fbc87e80ed79fdd4662b5e61cc8d2ba..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/decoder.py
+++ /dev/null
@@ -1,481 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import numpy as np
-import torch
-import torch.nn as nn
-from fairscale.nn import checkpoint_wrapper, wrap
-
-from torchscale.architecture.utils import init_bert_params
-from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
-from torchscale.component.multihead_attention import MultiheadAttention
-from torchscale.component.relative_position_bias import RelativePositionBias
-from torchscale.component.xmoe.moe_layer import MOELayer
-from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-try:
- from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
- from torch.nn import LayerNorm
-
-class DecoderLayer(nn.Module):
- def __init__(
- self,
- args,
- depth,
- is_moe_layer=False,
- is_encoder_decoder=False,
- ):
- super().__init__()
- self.args = args
- self.embed_dim = args.decoder_embed_dim
- self.dropout_module = torch.nn.Dropout(args.dropout)
-
- if args.drop_path_rate > 0:
- drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
- depth
- ]
- self.drop_path = DropPath(drop_path_prob)
- else:
- self.drop_path = None
-
- self.self_attn = self.build_self_attention(self.embed_dim, args)
-
- self.normalize_before = args.decoder_normalize_before
-
- self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
-
- if not is_encoder_decoder:
- self.encoder_attn = None
- self.encoder_attn_layer_norm = None
- else:
- self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
- self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
-
- self.is_moe_layer = is_moe_layer
- self.ffn_dim = args.decoder_ffn_embed_dim
-
- if not self.is_moe_layer:
- self.ffn = self.build_ffn(
- self.embed_dim,
- self.args,
- )
- else:
- if args.moe_top1_expert:
- gate = Top1Gate(
- self.embed_dim,
- args.moe_expert_count,
- use_fp32=args.moe_gating_use_fp32,
- moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
- use_xmoe=args.use_xmoe,
- )
- else:
- gate = Top2Gate(
- self.embed_dim,
- args.moe_expert_count,
- args.moe_gating_use_fp32,
- args.moe_second_expert_policy,
- args.moe_normalize_gate_prob_before_dropping,
- args.moe_eval_capacity_token_fraction,
- use_xmoe=args.use_xmoe,
- )
- experts = make_experts(args, self.embed_dim, self.ffn_dim)
- self.moe_layer = MOELayer(gate, experts, args)
-
- self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
-
- if args.deepnorm:
- if is_encoder_decoder:
- self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
- else:
- self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
- else:
- self.alpha = 1.0
-
- def build_ffn(self, embed_dim, args):
- return FeedForwardNetwork(
- embed_dim,
- self.ffn_dim,
- args.activation_fn,
- args.dropout,
- args.activation_dropout,
- args.layernorm_eps,
- args.subln,
- )
-
- def build_self_attention(self, embed_dim, args):
- return MultiheadAttention(
- args,
- embed_dim,
- args.decoder_attention_heads,
- dropout=args.attention_dropout,
- self_attention=True,
- encoder_decoder_attention=False,
- subln=args.subln,
- )
-
- def build_encoder_attention(self, embed_dim, args):
- return MultiheadAttention(
- args,
- embed_dim,
- args.decoder_attention_heads,
- dropout=args.attention_dropout,
- self_attention=False,
- encoder_decoder_attention=True,
- subln=args.subln,
- )
-
- def residual_connection(self, x, residual):
- return residual * self.alpha + x
-
- def forward(
- self,
- x,
- encoder_out=None,
- encoder_padding_mask=None,
- incremental_state=None,
- self_attn_mask=None,
- self_attn_padding_mask=None,
- self_attn_rel_pos=None,
- cross_attn_rel_pos=None,
- is_first_step=False,
- ):
- residual = x
- if self.normalize_before:
- x = self.self_attn_layer_norm(x)
-
- x, attn = self.self_attn(
- query=x,
- key=x,
- value=x,
- key_padding_mask=self_attn_padding_mask,
- incremental_state=incremental_state,
- attn_mask=self_attn_mask,
- rel_pos=self_attn_rel_pos,
- is_first_step=is_first_step,
- is_causal=True,
- )
- x = self.dropout_module(x)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.self_attn_layer_norm(x)
-
- if self.encoder_attn is not None and encoder_out is not None:
- residual = x
- if self.normalize_before:
- x = self.encoder_attn_layer_norm(x)
-
- x, attn = self.encoder_attn(
- query=x,
- key=encoder_out,
- value=encoder_out,
- key_padding_mask=encoder_padding_mask,
- incremental_state=None,
- rel_pos=cross_attn_rel_pos,
- )
- x = self.dropout_module(x)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.encoder_attn_layer_norm(x)
-
- residual = x
- if self.normalize_before:
- x = self.final_layer_norm(x)
- if not self.is_moe_layer:
- x = self.ffn(x)
- l_aux = None
- else:
- x, l_aux = self.moe_layer(x)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.final_layer_norm(x)
-
- return x, attn, None, l_aux
-
-
-class Decoder(nn.Module):
- def __init__(
- self,
- args,
- embed_tokens=None,
- embed_positions=None,
- output_projection=None,
- is_encoder_decoder=False,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.args = args
-
- self.dropout_module = torch.nn.Dropout(args.dropout)
-
- embed_dim = args.decoder_embed_dim
- self.embed_dim = embed_dim
- self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
-
- self.embed_tokens = embed_tokens
- self.embed_positions = embed_positions
-
- if (
- output_projection is None
- and not args.no_output_layer
- and args.vocab_size > 0
- ):
- self.output_projection = self.build_output_projection(args)
- else:
- self.output_projection = output_projection
-
- if args.layernorm_embedding:
- self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
- else:
- self.layernorm_embedding = None
-
- self.layers = nn.ModuleList([])
-
- moe_freq = args.moe_freq
- for i in range(args.decoder_layers):
- is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
- self.layers.append(
- self.build_decoder_layer(
- args,
- depth=i,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- )
-
- self.num_layers = len(self.layers)
-
- if args.decoder_normalize_before:
- self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
- else:
- self.layer_norm = None
-
- self.self_attn_relative_position = None
- self.cross_attn_relative_position = None
-
- if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
- self.self_attn_relative_position = RelativePositionBias(
- num_buckets=args.rel_pos_buckets,
- max_distance=args.max_rel_pos,
- n_heads=args.decoder_attention_heads,
- )
- if is_encoder_decoder:
- self.cross_attn_relative_position = RelativePositionBias(
- num_buckets=args.rel_pos_buckets,
- max_distance=args.max_rel_pos,
- n_heads=args.decoder_attention_heads,
- )
-
- if args.bert_init:
- self.apply(init_bert_params)
-
- if args.deepnorm:
- if is_encoder_decoder:
- init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
- else:
- init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
- for name, p in self.named_parameters():
- if (
- "fc1" in name
- or "fc2" in name
- or "out_proj" in name
- or "v_proj" in name
- ):
- p.data.div_(init_scale)
-
- if args.subln:
- if is_encoder_decoder:
- init_scale = math.sqrt(math.log(args.decoder_layers * 3))
- else:
- init_scale = math.sqrt(math.log(args.decoder_layers * 2))
- for name, p in self.named_parameters():
- if "encoder_attn" in name:
- continue
- if (
- "fc1" in name
- or "fc2" in name
- or "out_proj" in name
- or "v_proj" in name
- ):
- p.data.mul_(init_scale)
-
- def build_output_projection(
- self,
- args,
- ):
- if args.share_decoder_input_output_embed:
- output_projection = torch.nn.Linear(
- self.embed_tokens.weight.shape[1],
- self.embed_tokens.weight.shape[0],
- bias=False,
- )
- output_projection.weight = self.embed_tokens.weight
- else:
- output_projection = torch.nn.Linear(
- args.decoder_embed_dim, args.vocab_size, bias=False
- )
- torch.nn.init.normal_(
- output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
- )
- return output_projection
-
- def build_decoder_layer(
- self, args, depth, is_moe_layer=False, is_encoder_decoder=False
- ):
- layer = DecoderLayer(
- args,
- depth,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- if args.checkpoint_activations:
- layer = checkpoint_wrapper(layer)
- if args.fsdp:
- layer = wrap(layer)
- return layer
-
- def forward_embedding(
- self,
- tokens,
- token_embedding=None,
- incremental_state=None,
- ):
- positions = None
- if self.embed_positions is not None:
- positions = self.embed_positions(
- tokens, incremental_state=incremental_state
- )
-
- if incremental_state is not None and not self.is_first_step(incremental_state):
- tokens = tokens[:, -1:]
- if positions is not None:
- positions = positions[:, -1:]
-
- if token_embedding is None:
- token_embedding = self.embed_tokens(tokens)
-
- x = embed = self.embed_scale * token_embedding
-
- if positions is not None:
- x += positions
-
- if self.layernorm_embedding is not None:
- x = self.layernorm_embedding(x)
-
- x = self.dropout_module(x)
-
- return x, embed
-
- def is_first_step(self, incremental_state):
- if incremental_state is None:
- return False
- return incremental_state.get("is_first_step", False)
-
- def forward(
- self,
- prev_output_tokens,
- self_attn_padding_mask=None,
- encoder_out=None,
- incremental_state=None,
- features_only=False,
- return_all_hiddens=False,
- token_embeddings=None,
- **kwargs
- ):
- # embed tokens and positions
- x, _ = self.forward_embedding(
- prev_output_tokens, token_embeddings, incremental_state
- )
- is_first_step = self.is_first_step(incremental_state)
-
- # relative position
- self_attn_rel_pos_bias = None
- slen = prev_output_tokens.size(1)
- if self.self_attn_relative_position is not None:
- self_attn_rel_pos_bias = self.self_attn_relative_position(
- batch_size=x.size(0), qlen=slen, klen=slen
- )
- if incremental_state is not None and not is_first_step:
- self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
- cross_attn_rel_pos_bias = None
- if self.cross_attn_relative_position is not None:
- cross_attn_rel_pos_bias = self.cross_attn_relative_position(
- batch_size=x.size(0),
- qlen=slen,
- klen=encoder_out["encoder_out"].size(1),
- )
- if incremental_state is not None and not is_first_step:
- cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
-
- # decoder layers
- inner_states = [x]
-
- if encoder_out is None:
- l_aux = []
- else:
- l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
-
- for idx, layer in enumerate(self.layers):
- if incremental_state is None or is_first_step:
- if not self.args.flash_attention:
- self_attn_mask = torch.triu(
- torch.zeros([x.size(1), x.size(1)])
- .float()
- .fill_(float("-inf"))
- .type_as(x),
- 1,
- )
- else:
- self_attn_mask = None
- if is_first_step and incremental_state is not None:
- if idx not in incremental_state:
- incremental_state[idx] = {}
- else:
- self_attn_mask = None
- if idx not in incremental_state:
- incremental_state[idx] = {}
-
- x, layer_attn, _, l_aux_i = layer(
- x,
- encoder_out["encoder_out"] if encoder_out is not None else None,
- encoder_out["encoder_padding_mask"]
- if encoder_out is not None
- else None,
- incremental_state[idx] if incremental_state is not None else None,
- self_attn_mask=self_attn_mask,
- self_attn_padding_mask=self_attn_padding_mask,
- self_attn_rel_pos=self_attn_rel_pos_bias,
- cross_attn_rel_pos=cross_attn_rel_pos_bias,
- is_first_step=is_first_step,
- )
- l_aux.append(l_aux_i)
- inner_states.append(x)
-
- if self.layer_norm is not None:
- x = self.layer_norm(x)
-
- if not features_only:
- x = self.output_layer(x)
-
- return x, {
- "inner_states": inner_states,
- "l_aux": l_aux,
- "attn": None,
- }
-
- def output_layer(self, features):
- return self.output_projection(features)
diff --git a/code/xtuner/model/torchscale/architecture/encoder.py b/code/xtuner/model/torchscale/architecture/encoder.py
deleted file mode 100644
index 57cba3a578afeffa1b7972955d2213690b5f7a0d..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/encoder.py
+++ /dev/null
@@ -1,400 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import numpy as np
-import torch
-import torch.nn as nn
-from fairscale.nn import checkpoint_wrapper, wrap
-try:
- from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
- from torch.nn import LayerNorm
-
-from torchscale.architecture.utils import init_bert_params
-from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
-from torchscale.component.multihead_attention import MultiheadAttention
-from torchscale.component.multiway_network import MultiwayWrapper, set_split_position
-from torchscale.component.relative_position_bias import RelativePositionBias
-from torchscale.component.xmoe.moe_layer import MOELayer
-from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-
-
-class EncoderLayer(nn.Module):
- def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
- super().__init__()
- self.args = args
- self.embed_dim = args.encoder_embed_dim
- self.self_attn = self.build_self_attention(self.embed_dim, args)
- self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
- self.dropout_module = torch.nn.Dropout(args.dropout)
-
- if args.drop_path_rate > 0:
- drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
- depth
- ]
- self.drop_path = DropPath(drop_path_prob)
- else:
- self.drop_path = None
-
- self.normalize_before = args.encoder_normalize_before
- self.is_moe_layer = is_moe_layer
- self.ffn_dim = args.encoder_ffn_embed_dim
-
- if not self.is_moe_layer:
- self.ffn = MultiwayWrapper(
- args,
- self.build_ffn(
- self.embed_dim,
- self.args,
- ),
- )
- else:
- assert not self.args.multiway
- if args.moe_top1_expert:
- gate = Top1Gate(
- self.embed_dim,
- args.moe_expert_count,
- use_fp32=args.moe_gating_use_fp32,
- moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
- use_xmoe=args.use_xmoe,
- )
- else:
- gate = Top2Gate(
- self.embed_dim,
- args.moe_expert_count,
- args.moe_gating_use_fp32,
- args.moe_second_expert_policy,
- args.moe_normalize_gate_prob_before_dropping,
- args.moe_eval_capacity_token_fraction,
- use_xmoe=args.use_xmoe,
- )
- experts = make_experts(args, self.embed_dim, self.ffn_dim)
- self.moe_layer = MOELayer(gate, experts, args)
- self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
-
- if args.deepnorm:
- if is_encoder_decoder:
- self.alpha = (
- math.pow(
- math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
- )
- * 0.81
- )
- else:
- self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
- else:
- self.alpha = 1.0
-
- def build_ffn(self, embed_dim, args):
- return FeedForwardNetwork(
- embed_dim,
- self.ffn_dim,
- args.activation_fn,
- args.dropout,
- args.activation_dropout,
- args.layernorm_eps,
- args.subln,
- )
-
- def build_self_attention(self, embed_dim, args):
- return MultiheadAttention(
- args,
- embed_dim,
- args.encoder_attention_heads,
- dropout=args.attention_dropout,
- self_attention=True,
- encoder_decoder_attention=False,
- subln=args.subln,
- )
-
- def residual_connection(self, x, residual):
- return residual * self.alpha + x
-
- def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None, multiway_split_position=None, incremental_state=None):
- if multiway_split_position is not None:
- assert self.args.multiway
- self.apply(set_split_position(multiway_split_position))
-
- if attn_mask is not None:
- attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
-
- residual = x
- if self.normalize_before:
- x = self.self_attn_layer_norm(x)
- x, _ = self.self_attn(
- query=x,
- key=x,
- value=x,
- key_padding_mask=encoder_padding_mask,
- attn_mask=attn_mask,
- rel_pos=rel_pos,
- incremental_state=incremental_state,
- )
- x = self.dropout_module(x)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.self_attn_layer_norm(x)
-
- residual = x
- if self.normalize_before:
- x = self.final_layer_norm(x)
- if not self.is_moe_layer:
- x = self.ffn(x)
- l_aux = None
- else:
- x = x.transpose(0, 1)
- x, l_aux = self.moe_layer(x)
- x = x.transpose(0, 1)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.final_layer_norm(x)
- return x, l_aux
-
-
-class Encoder(nn.Module):
- def __init__(
- self,
- args,
- embed_tokens=None,
- embed_positions=None,
- output_projection=None,
- is_encoder_decoder=False,
- **kwargs
- ):
- self.args = args
- super().__init__(**kwargs)
-
- self.dropout_module = torch.nn.Dropout(args.dropout)
-
- embed_dim = args.encoder_embed_dim
- self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
-
- self.embed_tokens = embed_tokens
- self.embed_positions = embed_positions
-
- if (
- output_projection is None
- and not is_encoder_decoder
- and not args.no_output_layer
- and args.vocab_size > 0
- ):
- self.output_projection = self.build_output_projection(args)
- else:
- self.output_projection = output_projection
-
- if args.layernorm_embedding:
- self.layernorm_embedding = MultiwayWrapper(
- args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1
- )
- else:
- self.layernorm_embedding = None
-
- self.layers = nn.ModuleList([])
-
- moe_freq = args.moe_freq
- for i in range(args.encoder_layers):
- is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
- self.layers.append(
- self.build_encoder_layer(
- args,
- depth=i,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- )
- self.num_layers = len(self.layers)
-
- if args.encoder_normalize_before and args.normalize_output:
- self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
- else:
- self.layer_norm = None
-
- if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
- self.relative_position = RelativePositionBias(
- num_buckets=args.rel_pos_buckets,
- max_distance=args.max_rel_pos,
- n_heads=args.encoder_attention_heads,
- )
- else:
- self.relative_position = None
-
- if args.bert_init:
- self.apply(init_bert_params)
-
- if args.deepnorm:
- if is_encoder_decoder:
- init_scale = (
- math.pow(
- math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
- )
- / 1.15
- )
- else:
- init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
- for name, p in self.named_parameters():
- if (
- "fc1" in name
- or "fc2" in name
- or "out_proj" in name
- or "v_proj" in name
- ):
- p.data.div_(init_scale)
-
- if args.subln:
- if is_encoder_decoder:
- init_scale = math.sqrt(
- math.log(3 * args.decoder_layers)
- * math.log(2 * args.encoder_layers)
- / 3
- )
- else:
- init_scale = math.sqrt(math.log(args.encoder_layers * 2))
- for name, p in self.named_parameters():
- if (
- "fc1" in name
- or "fc2" in name
- or "out_proj" in name
- or "v_proj" in name
- ):
- p.data.mul_(init_scale)
-
- def build_output_projection(
- self,
- args,
- ):
- if args.share_encoder_input_output_embed:
- assert args.encoder_embedding_type == "language"
- output_projection = torch.nn.Linear(
- self.embed_tokens.weight.shape[1],
- self.embed_tokens.weight.shape[0],
- bias=False,
- )
- output_projection.weight = self.embed_tokens.weight
- else:
- output_projection = torch.nn.Linear(
- args.encoder_embed_dim, args.vocab_size, bias=False
- )
- torch.nn.init.normal_(
- output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5
- )
- return output_projection
-
- def build_encoder_layer(
- self, args, depth, is_moe_layer=False, is_encoder_decoder=False
- ):
- layer = EncoderLayer(
- args,
- depth,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- if args.checkpoint_activations:
- layer = checkpoint_wrapper(layer)
- if args.fsdp:
- layer = wrap(layer)
- return layer
-
- def forward_embedding(
- self,
- src_tokens,
- token_embedding=None,
- positions=None,
- ):
- if token_embedding is None:
- token_embedding = self.embed_tokens(src_tokens)
- x = embed = self.embed_scale * token_embedding
- if self.embed_positions is not None:
- if src_tokens is not None:
- x = embed + self.embed_positions(src_tokens, positions=positions)
- else:
- x = embed + self.embed_positions(x, positions=positions)
- if self.layernorm_embedding is not None:
- x = self.layernorm_embedding(x)
- x = self.dropout_module(x)
- return x, embed
-
- def forward(
- self,
- src_tokens,
- encoder_padding_mask=None,
- attn_mask=None,
- return_all_hiddens=False,
- token_embeddings=None,
- multiway_split_position=None,
- features_only=False,
- incremental_state=None,
- positions=None,
- **kwargs
- ):
- assert src_tokens is not None or token_embeddings is not None
-
- if encoder_padding_mask is None:
- if src_tokens is not None:
- encoder_padding_mask = torch.zeros_like(
- src_tokens, device=src_tokens.device
- ).bool()
- else:
- encoder_padding_mask = torch.zeros(
- [token_embeddings.size(0), token_embeddings.size(1)],
- device=token_embeddings.device,
- ).bool()
-
- if multiway_split_position is not None:
- assert self.args.multiway
- self.apply(set_split_position(multiway_split_position))
-
- x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
- x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
-
- encoder_states = []
-
- if return_all_hiddens:
- encoder_states.append(x)
-
- rel_pos_bias = None
- if self.relative_position is not None:
- rel_pos_bias = self.relative_position(
- batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
- )
-
- # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
- l_aux = []
- for idx, layer in enumerate(self.layers):
- x, l_aux_i = layer(
- x,
- encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
- attn_mask=attn_mask,
- rel_pos=rel_pos_bias,
- multiway_split_position=multiway_split_position,
- incremental_state=incremental_state[idx] if incremental_state is not None else None,
- )
- if return_all_hiddens:
- assert encoder_states is not None
- encoder_states.append(x)
- l_aux.append(l_aux_i)
-
- if self.layer_norm is not None:
- x = self.layer_norm(x)
-
- if not features_only and self.output_projection is not None:
- print('doing output projection')
- x = self.output_projection(x)
-
- return {
- "encoder_out": x,
- "encoder_embedding": encoder_embedding,
- "encoder_padding_mask": encoder_padding_mask,
- "encoder_states": encoder_states,
- "l_aux": l_aux,
- }
diff --git a/code/xtuner/model/torchscale/architecture/encoder_decoder.py b/code/xtuner/model/torchscale/architecture/encoder_decoder.py
deleted file mode 100644
index 91a906ec4a5acec203161443c563832cedda7a9c..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/encoder_decoder.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch.nn as nn
-
-from torchscale.architecture.decoder import Decoder
-from torchscale.architecture.encoder import Encoder
-
-
-class EncoderDecoder(nn.Module):
- def __init__(
- self,
- args,
- encoder_embed_tokens=None,
- encoder_embed_positions=None,
- decoder_embed_tokens=None,
- decoder_embed_positions=None,
- output_projection=None,
- **kwargs
- ):
- super().__init__()
- self.args = args
- if args.share_all_embeddings:
- args.share_decoder_input_output_embed = True
-
- self.encoder = Encoder(
- args,
- encoder_embed_tokens,
- encoder_embed_positions,
- is_encoder_decoder=True,
- **kwargs
- )
-
- if args.share_all_embeddings and decoder_embed_tokens is None:
- decoder_embed_tokens = self.encoder.embed_tokens
-
- self.decoder = Decoder(
- args,
- decoder_embed_tokens,
- decoder_embed_positions,
- output_projection,
- is_encoder_decoder=True,
- **kwargs
- )
-
- def forward(
- self,
- src_tokens,
- prev_output_tokens,
- return_all_hiddens=False,
- features_only=False,
- **kwargs
- ):
- encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
- decoder_out = self.decoder(
- prev_output_tokens,
- encoder_out=encoder_out,
- features_only=features_only,
- return_all_hiddens=return_all_hiddens,
- )
- return decoder_out
diff --git a/code/xtuner/model/torchscale/architecture/retnet.py b/code/xtuner/model/torchscale/architecture/retnet.py
deleted file mode 100644
index b29928cb1f2fdb883b9d0cfdbb778295376176e1..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/retnet.py
+++ /dev/null
@@ -1,391 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from fairscale.nn import checkpoint_wrapper, wrap
-
-from torchscale.architecture.utils import init_bert_params
-from torchscale.component.droppath import DropPath
-from torchscale.component.feedforward_network import make_experts
-from torchscale.component.gate_linear_unit import GLU
-from torchscale.component.multiscale_retention import MultiScaleRetention
-from torchscale.component.xmoe.moe_layer import MOELayer
-from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-from torchscale.component.rms_norm import RMSNorm
-
-
-class RetNetRelPos(nn.Module):
- def __init__(self, args):
- super().__init__()
- angle = 1.0 / (10000 ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
- angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
- decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
- self.register_buffer("angle", angle)
- self.register_buffer("decay", decay)
- self.recurrent_chunk_size = args.recurrent_chunk_size
-
- def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
- if activate_recurrent:
- sin = torch.sin(self.angle * (slen - 1))
- cos = torch.cos(self.angle * (slen - 1))
- retention_rel_pos = ((sin, cos), self.decay.exp())
- elif chunkwise_recurrent:
- index = torch.arange(slen).to(self.decay)
- sin = torch.sin(index[:, None] * self.angle[None, :])
- cos = torch.cos(index[:, None] * self.angle[None, :])
-
- block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
- mask = torch.tril(torch.ones(self.recurrent_chunk_size, self.recurrent_chunk_size).to(self.decay))
- mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
- mask = torch.exp(mask * self.decay[:, None, None])
- mask = torch.nan_to_num(mask)
-
- value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
- value_inner_decay = value_inner_decay.unsqueeze(-1)
- scale = mask.sum(dim=-1, keepdim=True).sqrt()
- inner_mask = mask / scale
-
- cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
- query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
- query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
- cross_decay = cross_decay[:, None, None]
- retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
- else:
- index = torch.arange(slen).to(self.decay)
- sin = torch.sin(index[:, None] * self.angle[None, :])
- cos = torch.cos(index[:, None] * self.angle[None, :])
- mask = torch.tril(torch.ones(slen, slen).to(self.decay))
- mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
- mask = torch.exp(mask * self.decay[:, None, None])
- mask = torch.nan_to_num(mask)
- mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
- retention_rel_pos = ((sin, cos), mask)
-
- return retention_rel_pos
-
-class DecoderLayer(nn.Module):
- def __init__(
- self,
- args,
- depth,
- is_moe_layer=False,
- ):
- super().__init__()
- self.args = args
- self.embed_dim = args.decoder_embed_dim
- self.dropout_module = torch.nn.Dropout(args.dropout)
-
- if args.drop_path_rate > 0:
- drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
- depth
- ]
- self.drop_path = DropPath(drop_path_prob)
- else:
- self.drop_path = None
-
- self.retention = self.build_retention(self.embed_dim, args)
-
- self.normalize_before = args.decoder_normalize_before
-
- self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
-
- self.is_moe_layer = is_moe_layer
- self.ffn_dim = args.decoder_ffn_embed_dim
-
- if not self.is_moe_layer:
- self.ffn = self.build_ffn(
- self.embed_dim,
- self.args,
- )
- else:
- if args.moe_top1_expert:
- gate = Top1Gate(
- self.embed_dim,
- args.moe_expert_count,
- use_fp32=args.moe_gating_use_fp32,
- moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
- use_xmoe=args.use_xmoe,
- )
- else:
- gate = Top2Gate(
- self.embed_dim,
- args.moe_expert_count,
- args.moe_gating_use_fp32,
- args.moe_second_expert_policy,
- args.moe_normalize_gate_prob_before_dropping,
- args.moe_eval_capacity_token_fraction,
- use_xmoe=args.use_xmoe,
- )
- experts = make_experts(args, self.embed_dim, self.ffn_dim)
- self.moe_layer = MOELayer(gate, experts, args)
-
- self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
-
- if args.deepnorm:
- self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
- else:
- self.alpha = 1.0
-
- def build_ffn(self, embed_dim, args):
- return GLU(
- embed_dim,
- self.ffn_dim,
- args.activation_fn,
- args.dropout,
- args.activation_dropout,
- )
-
- def build_retention(self, embed_dim, args):
- return MultiScaleRetention(
- args,
- embed_dim,
- args.decoder_value_embed_dim,
- args.decoder_retention_heads,
- )
-
- def residual_connection(self, x, residual):
- return residual * self.alpha + x
-
- def forward(
- self,
- x,
- incremental_state=None,
- chunkwise_recurrent=False,
- retention_rel_pos=None,
- ):
- residual = x
- if self.normalize_before:
- x = self.retention_layer_norm(x)
-
- x = self.retention(
- x,
- incremental_state=incremental_state,
- rel_pos=retention_rel_pos,
- chunkwise_recurrent=chunkwise_recurrent,
- )
- x = self.dropout_module(x)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.retention_layer_norm(x)
-
- residual = x
- if self.normalize_before:
- x = self.final_layer_norm(x)
- if not self.is_moe_layer:
- x = self.ffn(x)
- l_aux = None
- else:
- x, l_aux = self.moe_layer(x)
-
- if self.drop_path is not None:
- x = self.drop_path(x)
-
- x = self.residual_connection(x, residual)
- if not self.normalize_before:
- x = self.final_layer_norm(x)
-
- return x, l_aux
-
-
-class RetNetDecoder(nn.Module):
- def __init__(
- self,
- args,
- embed_tokens=None,
- output_projection=None,
- **kwargs
- ):
- super().__init__(**kwargs)
- self.args = args
-
- self.dropout_module = torch.nn.Dropout(args.dropout)
-
- embed_dim = args.decoder_embed_dim
- self.embed_dim = embed_dim
- self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
-
- self.embed_tokens = embed_tokens
-
- if (
- output_projection is None
- and not args.no_output_layer
- and args.vocab_size > 0
- ):
- self.output_projection = self.build_output_projection(args)
- else:
- self.output_projection = output_projection
-
- if args.layernorm_embedding:
- self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
- else:
- self.layernorm_embedding = None
-
- self.layers = nn.ModuleList([])
-
- moe_freq = args.moe_freq
- for i in range(args.decoder_layers):
- is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
- self.layers.append(
- self.build_decoder_layer(
- args,
- depth=i,
- is_moe_layer=is_moe_layer,
- )
- )
-
- self.num_layers = len(self.layers)
-
- if args.decoder_normalize_before:
- self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
- else:
- self.layer_norm = None
-
- self.retnet_rel_pos = RetNetRelPos(args)
- self.chunkwise_recurrent = args.chunkwise_recurrent
- self.recurrent_chunk_size = args.recurrent_chunk_size
-
-
- if args.deepnorm:
- init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
- for name, p in self.named_parameters():
- if (
- "fc1" in name
- or "fc2" in name
- or "out_proj" in name
- or "v_proj" in name
- ):
- p.data.div_(init_scale)
-
- def build_output_projection(
- self,
- args,
- ):
- if args.share_decoder_input_output_embed:
- output_projection = torch.nn.Linear(
- self.embed_tokens.weight.shape[1],
- self.embed_tokens.weight.shape[0],
- bias=False,
- )
- output_projection.weight = self.embed_tokens.weight
- else:
- output_projection = torch.nn.Linear(
- args.decoder_embed_dim, args.vocab_size, bias=False
- )
- torch.nn.init.normal_(
- output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
- )
- return output_projection
-
- def build_decoder_layer(
- self, args, depth, is_moe_layer=False
- ):
- layer = DecoderLayer(
- args,
- depth,
- is_moe_layer=is_moe_layer,
- )
- if args.checkpoint_activations:
- layer = checkpoint_wrapper(layer)
- if args.fsdp:
- layer = wrap(layer)
- return layer
-
- def forward_embedding(
- self,
- tokens,
- token_embedding=None,
- incremental_state=None,
- ):
- if incremental_state is not None and not self.is_first_step(incremental_state):
- tokens = tokens[:, -1:]
-
- if token_embedding is None:
- token_embedding = self.embed_tokens(tokens)
-
- x = embed = self.embed_scale * token_embedding
-
- if self.layernorm_embedding is not None:
- x = self.layernorm_embedding(x)
-
- x = self.dropout_module(x)
-
- return x, embed
-
- def is_first_step(self, incremental_state):
- if incremental_state is None:
- return False
- return incremental_state.get("is_first_step", False)
-
- def forward(
- self,
- prev_output_tokens,
- incremental_state=None,
- features_only=False,
- return_all_hiddens=False,
- token_embeddings=None,
- **kwargs
- ):
- # embed tokens
- x, _ = self.forward_embedding(
- prev_output_tokens, token_embeddings, incremental_state
- )
- is_first_step = self.is_first_step(incremental_state)
-
-
- if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
- padding_len = self.recurrent_chunk_size - prev_output_tokens.size(1) % self.recurrent_chunk_size
- slen = prev_output_tokens.size(1) + padding_len
- x = F.pad(x, (0, 0, 0, padding_len))
- else:
- slen = prev_output_tokens.size(1)
- # relative position
- retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
- # decoder layers
- inner_states = [x]
-
- l_aux = []
-
- for idx, layer in enumerate(self.layers):
- if incremental_state is None or is_first_step:
- if is_first_step and incremental_state is not None:
- if idx not in incremental_state:
- incremental_state[idx] = {}
- else:
- if idx not in incremental_state:
- incremental_state[idx] = {}
-
- x, l_aux_i = layer(
- x,
- incremental_state[idx] if incremental_state is not None else None,
- retention_rel_pos=retention_rel_pos,
- chunkwise_recurrent=self.chunkwise_recurrent,
- )
- l_aux.append(l_aux_i)
- inner_states.append(x)
-
- if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
- x = x[:, :prev_output_tokens.size(1), :]
-
- if self.layer_norm is not None:
- x = self.layer_norm(x)
-
- if not features_only:
- x = self.output_layer(x)
-
- return x, {
- "inner_states": inner_states,
- "l_aux": l_aux,
- "attn": None,
- }
-
- def output_layer(self, features):
- return self.output_projection(features)
diff --git a/code/xtuner/model/torchscale/architecture/utils.py b/code/xtuner/model/torchscale/architecture/utils.py
deleted file mode 100644
index 58a5c15be538bc6ca0526f6755c28462b5d8e263..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/architecture/utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch.nn as nn
-
-from torchscale.component.multihead_attention import MultiheadAttention
-from torchscale.component.multiway_network import MultiwayNetwork
-
-
-def init_bert_params(module):
- def normal_(data):
- data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
-
- if isinstance(module, nn.Linear):
- normal_(module.weight.data)
- if module.bias is not None:
- module.bias.data.zero_()
- if isinstance(module, nn.Embedding):
- normal_(module.weight.data)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
- if isinstance(module, MultiheadAttention):
- if isinstance(module.q_proj, MultiwayNetwork):
- normal_(module.q_proj.A.weight.data)
- normal_(module.q_proj.B.weight.data)
- normal_(module.k_proj.A.weight.data)
- normal_(module.k_proj.B.weight.data)
- normal_(module.v_proj.A.weight.data)
- normal_(module.v_proj.B.weight.data)
- else:
- normal_(module.q_proj.weight.data)
- normal_(module.k_proj.weight.data)
- normal_(module.v_proj.weight.data)
diff --git a/code/xtuner/model/torchscale/component/__init__.py b/code/xtuner/model/torchscale/component/__init__.py
deleted file mode 100644
index 3ae31e2507e8759f2ac7f85e517288f536c04ac3..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
diff --git a/code/xtuner/model/torchscale/component/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 9fa67e789758407b21b307d21f74708d45ed7152..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/dilated_attention.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/dilated_attention.cpython-311.pyc
deleted file mode 100644
index 54b53ffacc9e0c8645886bcb3c9334b6580d99f9..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/dilated_attention.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/droppath.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/droppath.cpython-311.pyc
deleted file mode 100644
index 07d1d96d0f04c3d9357811c2754f9391676d2488..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/droppath.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/feedforward_network.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/feedforward_network.cpython-311.pyc
deleted file mode 100644
index e765a90004ee59c29ec6ca9a441e3dd4fad71690..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/feedforward_network.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/flash_attention.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/flash_attention.cpython-311.pyc
deleted file mode 100644
index a348f9c5d9bc6c661d120468ac0afa4aed954b3e..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/flash_attention.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/multihead_attention.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/multihead_attention.cpython-311.pyc
deleted file mode 100644
index fff3312e62b39834849fa3da5ca4597add636ad3..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/multihead_attention.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/multiway_network.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/multiway_network.cpython-311.pyc
deleted file mode 100644
index 7d82489492ffdc99708912038542a911194b85e5..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/multiway_network.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/relative_position_bias.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/relative_position_bias.cpython-311.pyc
deleted file mode 100644
index 5c54ffed73d814db2f326f0351efd568b1c7bd9f..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/relative_position_bias.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/utils.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/utils.cpython-311.pyc
deleted file mode 100644
index e39033c3ce5325d20cca1453fa585957dc5b39ab..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/utils.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/__pycache__/xpos_relative_position.cpython-311.pyc b/code/xtuner/model/torchscale/component/__pycache__/xpos_relative_position.cpython-311.pyc
deleted file mode 100644
index c65ba8731584d8067674863a5cbe56f372f667cb..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/__pycache__/xpos_relative_position.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/custom_dilated_attention.py b/code/xtuner/model/torchscale/component/custom_dilated_attention.py
deleted file mode 100644
index 3f726c4c0391d3f410d25576851acf4f1c4dc142..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/custom_dilated_attention.py
+++ /dev/null
@@ -1,233 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-
-from .multihead_attention import MultiheadAttention
-from .utils import padding_to_multiple_of, all_gather_func, get_data_parallel_rank, get_data_parallel_world_size
-
-
-class DilatedAttention(MultiheadAttention):
-
- def dense_to_sparse(self, x, ratio):
- length = x.size(1)
- padding = padding_to_multiple_of(length, ratio)
- head_padding = padding_to_multiple_of(self.num_heads, ratio)
-
- if padding > 0 or head_padding > 0:
- x = F.pad(x, (0, 0, 0, head_padding, 0, padding), value = 0.)
-
- x = rearrange(x, 'b (l r1) (r2 h) d -> b l h d r1 r2', r1=ratio, r2=ratio)
- x = torch.diagonal(x, offset=0, dim1=4, dim2=5)
- x = rearrange(x, 'b l h d r -> b l (r h) d')
-
- if head_padding > 0:
- x = x[:, :, :self.num_heads]
-
- return x
-
- def sparse_to_dense(self, out, lse, ratio):
- head_padding = padding_to_multiple_of(self.num_heads, ratio)
-
- if head_padding > 0:
- out = F.pad(out, (0, 0, 0, head_padding), value = 0.)
- lse = F.pad(lse, (0, 0, 0, head_padding), value = -1e8)
-
- out = rearrange(out, 'b l (r h) d -> b l h d r', r=ratio)
- out = torch.diag_embed(out, offset=0, dim1=4, dim2=5)
- out = rearrange(out, 'b l h d r1 r2 -> b (r2 h) (l r1) d', r1=ratio, r2=ratio)
-
- lse = rearrange(lse, 'b (r h) l -> b l h r', r=ratio)
- lse = torch.diag_embed(lse, offset=0, dim1=3, dim2=4)
- lse = lse.masked_fill_(lse==0, -1e8)
- lse = rearrange(lse, 'b l h r1 r2 -> b (r2 h) (l r1) 1', r1=ratio, r2=ratio)
-
- if head_padding > 0:
- out = out[:, :self.num_heads]
- lse = lse[:, :self.num_heads]
-
- return out, lse
-
- def gather_kv(self, x, sl, seq_len, is_causal=True):
- bsz = x.size(0)
- assert sl % seq_len == 0
- num_rank_per_segment = sl // seq_len
-
- x = all_gather_func(x)
- current_rank = get_data_parallel_rank()
- x = rearrange(x, '(w b) l h d -> w b l h d', b=bsz)
-
- if is_causal:
- if current_rank > 0:
- x = x[:current_rank]
- else:
- x = x[:1] * 0
-
- current_segment = current_rank // num_rank_per_segment * num_rank_per_segment
- x = x[current_segment:current_segment+num_rank_per_segment]
-
- x = rearrange(x, 'w b l h d -> b (w l) h d')
- return x
-
- def gathering(self, x, dr, sl, is_causal=True, offset=0, is_kv=False, seq_parall=True):
-
- curr_x = x
- if offset > 0:
- curr_x = F.pad(curr_x, (0, 0, 0, 0, offset % sl, 0), value=0.)
- seq_len = curr_x.size(1)
- should_gather_kv = is_kv and (get_data_parallel_world_size() > 1) and (sl > seq_len) and seq_parall
- _sl = sl
- sl = min(sl, seq_len)
- padding = padding_to_multiple_of(seq_len, sl)
-
- if padding > 0:
- curr_x = F.pad(curr_x, (0, 0, 0, 0, 0, padding), value = 0.)
-
- curr_x = rearrange(curr_x, 'b (n g) h d -> (b n) g h d', g=sl)
- curr_x = self.dense_to_sparse(curr_x, dr)
-
- if should_gather_kv:
- curr_x = self.gather_kv(curr_x, _sl, seq_len, is_causal)
-
- curr_x = rearrange(curr_x, 'b l h d -> (b h) l d')
-
- return curr_x
-
- def scattering(self, outs, lses, seq_len, bsz, offset=0):
- assert len(outs) == len(lses)
- assert len(outs) % len(self.args.dilated_ratio) == 0
- all_outs, all_lses = [], []
- drs = self.args.dilated_ratio
- if len(outs) > len(drs):
- drs = drs * (len(outs) // len(drs))
-
- for dr, o, lse in zip(drs, outs, lses):
- o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads)
- o, lse = self.sparse_to_dense(o, lse, dr)
- o = rearrange(o, '(b n) h g d -> (b h) (n g) d', b=bsz)
- lse = rearrange(lse, '(b n) h g 1 -> (b h) (n g) 1', b=bsz)
- o = o[:, offset:offset+seq_len]
- lse = lse[:, offset:offset+seq_len]
-
- all_outs.append(o)
- all_lses.append(lse)
-
- with torch.no_grad():
- # added by Hanwen, replace nan with 0
- for lse in all_lses:
- if torch.isnan(lse).any():
- print("Warning: Flash Attention 2 has NaN softmax_lse")
- # replace nan with 0
- lse = lse.masked_fill_(torch.isnan(lse), 0)
- max_lse = torch.stack(all_lses, dim=0)
- max_lse = max_lse.max(0)[0]
- all_lses = [torch.exp(lse-max_lse) for lse in all_lses]
- lse_sum = torch.stack(all_lses, dim=0).sum(0)
- all_lses = [lse / lse_sum for lse in all_lses]
-
- out = 0
- for o, lse in zip(all_outs, all_lses):
- out += o * lse.type_as(o)
- out = rearrange(out, '(b h) l d -> b l (h d)', h=self.num_heads)
-
- return out
-
- def forward(
- self,
- query,
- key,
- value,
- incremental_state=None,
- key_padding_mask=None,
- attn_mask=None,
- rel_pos=None,
- is_first_step=False,
- is_causal=False,
- ):
- assert self.args.flash_attention
- assert rel_pos is None
- bsz, tgt_len, embed_dim = query.size()
- src_len = tgt_len
- assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-
- key_bsz, src_len, _ = key.size()
- assert key_bsz == bsz, f"{query.size(), key.size()}"
- assert value is not None
- assert bsz, src_len == value.shape[:2]
-
- q = self.q_proj(query)
- k = self.k_proj(key)
- v = self.v_proj(value)
-
- q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads)
- k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads)
- v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads)
-
- if incremental_state is not None and not is_first_step:
- offset = src_len - 1
- else:
- offset = 0
-
- if incremental_state is not None:
- if "prev_key" in incremental_state:
- prev_key = incremental_state["prev_key"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- prev_value = incremental_state["prev_value"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- k = torch.cat([prev_key, k], dim=1)
- v = torch.cat([prev_value, v], dim=1)
- incremental_state["prev_key"] = k.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- incremental_state["prev_value"] = v.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- src_len = k.size(1)
-
- if self.xpos is not None:
- if incremental_state is not None and not is_first_step:
- offset = src_len - 1
- else:
- offset = 0
- k = self.xpos(k, offset=0, downscale=True)
- q = self.xpos(q, offset=offset, downscale=False)
-
- q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
- k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads)
- v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads)
-
- # added by Hanwen, split key_padding_mask
- if key_padding_mask is not None:
- key_padding_mask = key_padding_mask.view(bsz, src_len, 1, 1).expand(-1, -1, self.num_heads, -1)
-
- outs, lses = [], []
- for sl, dr in zip(self.args.segment_length, self.args.dilated_ratio):
- ki = self.gathering(k, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
- vi = self.gathering(v, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
- qi = self.gathering(q, dr, sl, is_causal=is_causal, offset=offset, is_kv=False, seq_parall=self.args.seq_parallel)
-
- # added by Hanwen, split key_padding_mask
- if key_padding_mask is not None:
- key_padding_mask_i = self.gathering(key_padding_mask, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
- else:
- key_padding_mask_i = None
-
- out, lse = self.attention_ops(qi, ki, vi, key_padding_mask=key_padding_mask_i, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
-
- outs.append(out)
- lses.append(lse)
-
- attn = self.scattering(outs, lses, tgt_len, bsz, offset=offset)
-
- if self.inner_attn_ln is not None:
- attn = self.inner_attn_ln(attn)
-
- attn = self.out_proj(attn)
-
- return attn, None
diff --git a/code/xtuner/model/torchscale/component/custom_flash_attention.py b/code/xtuner/model/torchscale/component/custom_flash_attention.py
deleted file mode 100644
index ecfe61c3fb4c7951dc4bce57716874232107fa04..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/custom_flash_attention.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright (c) 2023 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-
-from typing import Any, Optional
-import torch
-
-
-if torch.cuda.is_available():
- try:
- if torch.cuda.get_device_capability()[0] > 7:
- from einops import rearrange
- from flash_attn.bert_padding import pad_input, unpad_input
- from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
- from flash_attn.flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as _flash_attn_varlen_qkvpacked_func
-
- '''
- The official implementation of using Flash Attention 2 in LongNet
- '''
-
- def flash_attn_varlen_func(q, k, v, dropout=0.0, bias=None, key_padding_mask=None, softmax_scale=None, is_causal=False):
- # q, k, v: [b, s, h, d]
- assert bias is None
-
- # stack the concatentate q, k, v in to [b, s, 3, h, d]
- qkv = torch.stack([q, k, v], dim=2)
- batch_size, seqlen, _, nheads, dim = qkv.shape
-
- x = rearrange(qkv, 'b s three h d -> b s (three h d)')
- x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
- x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
- output_unpad, lse, _ = _flash_attn_varlen_qkvpacked_func(
- x_unpad, cu_seqlens, max_s, dropout,
- softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True
- )
- output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
- indices, batch_size, seqlen),
- 'b s (h d) -> b s h d', h=nheads)
-
- if torch.isnan(output).any() or torch.isnan(lse).any():
- print("Warning: Flash Attention 2 has NaN output")
- state_dict = {'qkv': qkv, 'out': output, 'softmax_lse': lse,
- 'key_padding_mask': key_padding_mask,
- 'cu_seqlens': cu_seqlens, 'max_seqlen': max_s,
- 'dropout_p': dropout, 'softmax_scale': softmax_scale}
- torch.save(state_dict, 'nan_repro.pt')
-
- return output, lse
-
- def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
- assert bias is None
- attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
- return attn, lse
-
- print("\033[92mUsing Flash Attention 2\033[0m")
- else:
- print("\033[91mWarning: Flash Attention 2 is not successfully loaded. Using XFormers instead.\033[0m")
- from xformers.ops.fmha import (
- cutlass,
- Inputs,
- Context,
- _memory_efficient_attention_forward_requires_grad,
- _memory_efficient_attention_backward,
- LowerTriangularMask,
- )
-
- class FlashAttnFunc(torch.autograd.Function):
- @staticmethod
- # type: ignore
- def forward(ctx, q, k, v, dropout=0.0, bias=None, key_padding_mask=None, softmax_scale=None, is_causal=False):
- if is_causal:
- assert bias is None
- attn_bias = LowerTriangularMask()
- else:
- attn_bias = bias
-
- inp = Inputs(
- query=q,
- key=k,
- value=v,
- attn_bias=attn_bias,
- p=dropout,
- scale=softmax_scale,
- )
- op_fw = cutlass.FwOp
- op_bw = cutlass.BwOp
-
- out, op_ctx = _memory_efficient_attention_forward_requires_grad(
- inp=inp, op=op_fw
- )
-
- # Saving attn_bias is a bit complicated, as the
- # torch part should go in `save_for_backward`
- if isinstance(inp.attn_bias, torch.Tensor):
- attn_bias_tensor = inp.attn_bias
- attn_bias_ctx = None
- else:
- attn_bias_tensor = None
- attn_bias_ctx = inp.attn_bias
-
- ctx.save_for_backward(
- inp.query,
- inp.key,
- inp.value,
- op_ctx.out,
- op_ctx.lse,
- )
- ctx.rng_state = op_ctx.rng_state
- ctx.attn_bias_tensor = attn_bias_tensor
- if op_ctx.op_bw is not None:
- if op_bw is not None and op_bw is not op_ctx.op_bw:
- raise ValueError(
- f"Specified op_bw={op_bw.NAME}, but forward op "
- f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
- )
- op_bw = op_ctx.op_bw
- ctx.op_fw = op_fw
- ctx.op_bw = op_bw
- ctx.p = inp.p
-
- ctx.scale = inp.scale
- ctx.attn_bias_ctx = attn_bias_ctx
- return out, op_ctx.lse
-
- @staticmethod
- def deserialize_bias(
- attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
- ) -> Any:
- if attn_bias_tensor is None:
- return attn_bias_ctx
- return attn_bias_tensor
-
- @classmethod
- @torch.autograd.function.once_differentiable
- def backward(cls, ctx, grad, dlse):
- # Re-create context
- query, key, value, out, lse = ctx.saved_tensors
- attn_bias_tensor = ctx.attn_bias_tensor
- rng_state = ctx.rng_state
- inp = Inputs(
- query=query,
- key=key,
- value=value,
- attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
- p=ctx.p,
- scale=ctx.scale,
- )
- op_ctx = Context(
- lse=lse,
- out=out,
- rng_state=rng_state,
- )
- grads = _memory_efficient_attention_backward(
- ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
- )
- return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
-
- flash_attn_func = FlashAttnFunc.apply
- except ModuleNotFoundError:
- flash_attn_func = None
-else:
- flash_attn_func = None
diff --git a/code/xtuner/model/torchscale/component/custom_multihead_attention.py b/code/xtuner/model/torchscale/component/custom_multihead_attention.py
deleted file mode 100644
index 80ad600413e6e7b4f5fa698e4d22c748a2045a1d..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/custom_multihead_attention.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from einops import rearrange
-try:
- from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
- from torch.nn import LayerNorm
-
-from .multiway_network import MultiwayWrapper
-from .xpos_relative_position import XPOS
-from .flash_attention import flash_attn_func, flash_attn_varlen_func
-
-
-class MultiheadAttention(nn.Module):
- def __init__(
- self,
- args,
- embed_dim,
- num_heads,
- dropout=0.0,
- self_attention=False,
- encoder_decoder_attention=False,
- subln=False,
- ):
- super().__init__()
- self.args = args
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.head_dim = embed_dim // num_heads
- self.scaling = self.head_dim**-0.5
- self.dropout = dropout
-
- self.self_attention = self_attention
- self.encoder_decoder_attention = encoder_decoder_attention
- assert self.self_attention ^ self.encoder_decoder_attention
-
- self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
- self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
- self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
- self.out_proj = MultiwayWrapper(
- args, nn.Linear(embed_dim, embed_dim, bias=True)
- )
- self.inner_attn_ln = (
- MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
- if subln and self.self_attention
- else None
- )
- self.dropout_module = torch.nn.Dropout(dropout)
- self.xpos = (
- XPOS(self.head_dim, args.xpos_scale_base)
- if args.xpos_rel_pos and self.self_attention
- else None
- )
-
- def reset_parameters(self):
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.out_proj.weight)
- nn.init.constant_(self.out_proj.bias, 0.0)
-
- def attention_ops(self, q, k, v, key_padding_mask=None, attn_mask=None, rel_pos=None, is_causal=False):
- if not self.args.flash_attention:
- q *= self.scaling
- attn_weights = torch.bmm(q, k.transpose(1, 2))
-
- if attn_mask is not None:
- attn_weights = torch.nan_to_num(attn_weights)
- attn_mask = attn_mask.unsqueeze(0)
- attn_weights += attn_mask
-
- if key_padding_mask is not None:
- attn_weights = rearrange(attn_weights, '(b h) t s -> b h t s', h=self.num_heads)
- attn_weights = attn_weights.masked_fill(
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
- float("-inf"),
- )
- attn_weights = rearrange(attn_weights, 'b h t s -> (b h) t s')
-
- if rel_pos is not None:
- rel_pos = rel_pos.view(attn_weights.size())
- attn_weights = attn_weights + rel_pos
-
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
- attn_weights
- )
- attn_probs = self.dropout_module(attn_weights)
-
- attn = torch.bmm(attn_probs, v)
- attn = rearrange(attn, '(b h) l d -> b l (h d)', h=self.num_heads)
- else:
- assert flash_attn_func is not None
- assert flash_attn_varlen_func is not None
- assert rel_pos is None
- q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
- k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads)
- v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads)
- # added by Hanwen
-
- if key_padding_mask is not None:
- # if use key_padding_mask, then use the flash attention function supporting variable length
- key_padding_mask = rearrange(key_padding_mask, '(b h) l d -> b l h d', h=self.num_heads)
- # h head is redundant, so pick the first one
- key_padding_mask = key_padding_mask[:, :, 0, 0]
- # assert 1 means padding in key padding mask
- assert key_padding_mask[0, 0] == 0
- # convert the key_padding_mask to be compatible with flash attention
- key_padding_mask = 1 - key_padding_mask
- attn, lse = flash_attn_varlen_func(q, k, v, self.dropout, None, key_padding_mask, None, is_causal)
- else:
- attn, lse = flash_attn_func(q, k, v, self.dropout, None, None, is_causal)
-
- attn = rearrange(attn, 'b l h d -> b l (h d)')
- attn_weights = lse[:, :, :attn.size(1)]
-
- return attn, attn_weights
-
- def forward(
- self,
- query,
- key,
- value,
- incremental_state=None,
- key_padding_mask=None,
- attn_mask=None,
- rel_pos=None,
- is_first_step=False,
- is_causal=False,
- ):
- bsz, tgt_len, embed_dim = query.size()
- src_len = tgt_len
- assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-
- key_bsz, src_len, _ = key.size()
- assert key_bsz == bsz, f"{query.size(), key.size()}"
- assert value is not None
- assert bsz, src_len == value.shape[:2]
-
- q = self.q_proj(query)
- k = self.k_proj(key)
- v = self.v_proj(value)
-
- q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads)
- k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads)
- v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads)
-
- if incremental_state is not None:
- if "prev_key" in incremental_state:
- prev_key = incremental_state["prev_key"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- prev_value = incremental_state["prev_value"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- k = torch.cat([prev_key, k], dim=1)
- v = torch.cat([prev_value, v], dim=1)
- incremental_state["prev_key"] = k.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- incremental_state["prev_value"] = v.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- src_len = k.size(1)
-
- if self.xpos is not None:
- if incremental_state is not None and not is_first_step:
- offset = src_len - 1
- else:
- offset = 0
- k = self.xpos(k, offset=0, downscale=True)
- q = self.xpos(q, offset=offset, downscale=False)
-
- attn, attn_weights = self.attention_ops(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
-
- if self.inner_attn_ln is not None:
- attn = self.inner_attn_ln(attn)
-
- attn = self.out_proj(attn)
-
- return attn, attn_weights
diff --git a/code/xtuner/model/torchscale/component/dilated_attention.py b/code/xtuner/model/torchscale/component/dilated_attention.py
deleted file mode 100644
index 55f6d4efb598c474b9f617801ab4689fe1cd592a..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/dilated_attention.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-
-from .multihead_attention import MultiheadAttention
-from .utils import padding_to_multiple_of, all_gather_func, get_data_parallel_rank, get_data_parallel_world_size
-
-
-class DilatedAttention(MultiheadAttention):
-
- def dense_to_sparse(self, x, ratio):
- length = x.size(1)
- padding = padding_to_multiple_of(length, ratio)
- head_padding = padding_to_multiple_of(self.num_heads, ratio)
-
- if padding > 0 or head_padding > 0:
- x = F.pad(x, (0, 0, 0, head_padding, 0, padding), value = 0.)
-
- x = rearrange(x, 'b (l r1) (r2 h) d -> b l h d r1 r2', r1=ratio, r2=ratio)
- x = torch.diagonal(x, offset=0, dim1=4, dim2=5)
- x = rearrange(x, 'b l h d r -> b l (r h) d')
-
- if head_padding > 0:
- x = x[:, :, :self.num_heads]
-
- return x
-
- def sparse_to_dense(self, out, lse, ratio):
- head_padding = padding_to_multiple_of(self.num_heads, ratio)
-
- if head_padding > 0:
- out = F.pad(out, (0, 0, 0, head_padding), value = 0.)
- lse = F.pad(lse, (0, 0, 0, head_padding), value = -1e8)
-
- out = rearrange(out, 'b l (r h) d -> b l h d r', r=ratio)
- out = torch.diag_embed(out, offset=0, dim1=4, dim2=5)
- out = rearrange(out, 'b l h d r1 r2 -> b (r2 h) (l r1) d', r1=ratio, r2=ratio)
-
- lse = rearrange(lse, 'b (r h) l -> b l h r', r=ratio)
- lse = torch.diag_embed(lse, offset=0, dim1=3, dim2=4)
- lse = lse.masked_fill_(lse==0, -1e8)
- lse = rearrange(lse, 'b l h r1 r2 -> b (r2 h) (l r1) 1', r1=ratio, r2=ratio)
-
- if head_padding > 0:
- out = out[:, :self.num_heads]
- lse = lse[:, :self.num_heads]
-
- return out, lse
-
- def gather_kv(self, x, sl, seq_len, is_causal=True):
- bsz = x.size(0)
- assert sl % seq_len == 0
- num_rank_per_segment = sl // seq_len
-
- x = all_gather_func(x)
- current_rank = get_data_parallel_rank()
- x = rearrange(x, '(w b) l h d -> w b l h d', b=bsz)
-
- if is_causal:
- if current_rank > 0:
- x = x[:current_rank]
- else:
- x = x[:1] * 0
-
- current_segment = current_rank // num_rank_per_segment * num_rank_per_segment
- x = x[current_segment:current_segment+num_rank_per_segment]
-
- x = rearrange(x, 'w b l h d -> b (w l) h d')
- return x
-
- def gathering(self, x, dr, sl, is_causal=True, offset=0, is_kv=False, seq_parall=True):
-
- curr_x = x
- if offset > 0:
- curr_x = F.pad(curr_x, (0, 0, 0, 0, offset % sl, 0), value=0.)
- seq_len = curr_x.size(1)
- should_gather_kv = is_kv and (get_data_parallel_world_size() > 1) and (sl > seq_len) and seq_parall
- _sl = sl
- sl = min(sl, seq_len)
- padding = padding_to_multiple_of(seq_len, sl)
-
- if padding > 0:
- curr_x = F.pad(curr_x, (0, 0, 0, 0, 0, padding), value = 0.)
-
- curr_x = rearrange(curr_x, 'b (n g) h d -> (b n) g h d', g=sl)
- curr_x = self.dense_to_sparse(curr_x, dr)
-
- if should_gather_kv:
- curr_x = self.gather_kv(curr_x, _sl, seq_len, is_causal)
-
- curr_x = rearrange(curr_x, 'b l h d -> (b h) l d')
-
- return curr_x
-
- def scattering(self, outs, lses, seq_len, bsz, offset=0):
- assert len(outs) == len(lses)
- assert len(outs) % len(self.args.dilated_ratio) == 0
- all_outs, all_lses = [], []
- drs = self.args.dilated_ratio
- if len(outs) > len(drs):
- drs = drs * (len(outs) // len(drs))
-
- for dr, o, lse in zip(drs, outs, lses):
- o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads)
- o, lse = self.sparse_to_dense(o, lse, dr)
- o = rearrange(o, '(b n) h g d -> (b h) (n g) d', b=bsz)
- lse = rearrange(lse, '(b n) h g 1 -> (b h) (n g) 1', b=bsz)
- o = o[:, offset:offset+seq_len]
- lse = lse[:, offset:offset+seq_len]
-
- all_outs.append(o)
- all_lses.append(lse)
-
- with torch.no_grad():
- max_lse = torch.stack(all_lses, dim=0)
- max_lse = max_lse.max(0)[0]
- all_lses = [torch.exp(lse-max_lse) for lse in all_lses]
- lse_sum = torch.stack(all_lses, dim=0).sum(0)
- all_lses = [lse / lse_sum for lse in all_lses]
-
- out = 0
- for o, lse in zip(all_outs, all_lses):
- out += o * lse.type_as(o)
- out = rearrange(out, '(b h) l d -> b l (h d)', h=self.num_heads)
-
- return out
-
- def forward(
- self,
- query,
- key,
- value,
- incremental_state=None,
- key_padding_mask=None,
- attn_mask=None,
- rel_pos=None,
- is_first_step=False,
- is_causal=False,
- ):
- assert self.args.flash_attention
- assert rel_pos is None
- bsz, tgt_len, embed_dim = query.size()
- src_len = tgt_len
- assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-
- key_bsz, src_len, _ = key.size()
- assert key_bsz == bsz, f"{query.size(), key.size()}"
- assert value is not None
- assert bsz, src_len == value.shape[:2]
-
- q = self.q_proj(query)
- k = self.k_proj(key)
- v = self.v_proj(value)
-
- q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads)
- k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads)
- v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads)
-
- if incremental_state is not None and not is_first_step:
- offset = src_len - 1
- else:
- offset = 0
-
- if incremental_state is not None:
- if "prev_key" in incremental_state:
- prev_key = incremental_state["prev_key"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- prev_value = incremental_state["prev_value"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- k = torch.cat([prev_key, k], dim=1)
- v = torch.cat([prev_value, v], dim=1)
- incremental_state["prev_key"] = k.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- incremental_state["prev_value"] = v.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- src_len = k.size(1)
-
- if self.xpos is not None:
- if incremental_state is not None and not is_first_step:
- offset = src_len - 1
- else:
- offset = 0
- k = self.xpos(k, offset=0, downscale=True)
- q = self.xpos(q, offset=offset, downscale=False)
-
- q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
- k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads)
- v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads)
-
- outs, lses = [], []
- for sl, dr in zip(self.args.segment_length, self.args.dilated_ratio):
- ki = self.gathering(k, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
- vi = self.gathering(v, dr, sl, is_causal=is_causal, offset=0, is_kv=True, seq_parall=self.args.seq_parallel)
- qi = self.gathering(q, dr, sl, is_causal=is_causal, offset=offset, is_kv=False, seq_parall=self.args.seq_parallel)
-
- out, lse = self.attention_ops(qi, ki, vi, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
-
- outs.append(out)
- lses.append(lse)
-
- attn = self.scattering(outs, lses, tgt_len, bsz, offset=offset)
-
- if self.inner_attn_ln is not None:
- attn = self.inner_attn_ln(attn)
-
- attn = self.out_proj(attn)
-
- return attn, None
diff --git a/code/xtuner/model/torchscale/component/droppath.py b/code/xtuner/model/torchscale/component/droppath.py
deleted file mode 100644
index 18c06440816d67402470f8a0876e9c4806d172fc..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/droppath.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch.nn as nn
-from timm.models.layers import drop_path
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
-
- def extra_repr(self):
- return "p={}".format(self.drop_prob)
diff --git a/code/xtuner/model/torchscale/component/embedding.py b/code/xtuner/model/torchscale/component/embedding.py
deleted file mode 100644
index e633d5a692646ecde06885c11fa87efd048e9db1..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/embedding.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class VisionLanguageEmbedding(nn.Module):
- def __init__(self, text_embed, vision_embed):
- super().__init__()
- self.text_embed = text_embed
- self.vision_embed = vision_embed
-
- def forward(self, textual_tokens, visual_tokens, **kwargs):
- if textual_tokens is None:
- return self.vision_embed(visual_tokens)
-
- if visual_tokens is None:
- return self.text_embed(textual_tokens)
-
- x1 = self.vision_embed(visual_tokens)
- x2 = self.text_embed(textual_tokens)
-
- return torch.cat([x1, x2], dim=1)
-
-
-class VisionEmbedding(nn.Module):
- """Image to Patch Embedding"""
-
- def __init__(
- self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- embed_dim=768,
- contain_mask_token=False,
- prepend_cls_token=False,
- ):
- super().__init__()
- img_size = (img_size, img_size)
- patch_size = (patch_size, patch_size)
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
- self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
- self.img_size = img_size
- self.patch_size = patch_size
- self.num_patches = num_patches
-
- self.proj = nn.Conv2d(
- in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
- )
-
- if contain_mask_token:
- self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- else:
- self.mask_token = None
-
- if prepend_cls_token:
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- else:
- self.cls_token = None
-
- def num_position_embeddings(self):
- if self.cls_token is None:
- return self.num_patches
- else:
- return self.num_patches + 1
-
- def forward(self, x, masked_position=None, **kwargs):
- B, C, H, W = x.shape
- assert (
- H == self.img_size[0] and W == self.img_size[1]
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2)
-
- batch_size, seq_len, _ = x.size()
-
- if masked_position is not None:
- assert self.mask_token is not None
- mask_token = self.mask_token.expand(batch_size, seq_len, -1)
- w = masked_position.unsqueeze(-1).type_as(mask_token)
- x = x * (1 - w) + mask_token * w
-
- if self.cls_token is not None:
- cls_tokens = self.cls_token.expand(
- batch_size, -1, -1
- ) # stole cls_tokens impl from Phil Wang, thanks
- x = torch.cat((cls_tokens, x), dim=1)
-
- return x
-
-
-class TextEmbedding(nn.Embedding):
- def reset_parameters(self):
- nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
- self._fill_padding_idx_with_zero()
-
-
-class PositionalEmbedding(nn.Embedding):
- def forward(
- self,
- x,
- positions=None,
- **kwargs,
- ):
- if positions is None:
- # being consistent with Fairseq, which starts from 2.
- positions = (
- torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
- )
- return F.embedding(
- positions,
- self.weight,
- self.padding_idx,
- self.max_norm,
- self.norm_type,
- self.scale_grad_by_freq,
- self.sparse,
- )
diff --git a/code/xtuner/model/torchscale/component/feedforward_network.py b/code/xtuner/model/torchscale/component/feedforward_network.py
deleted file mode 100644
index 9d0295d94a61b4c2379f460c508ef5b50960a43e..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/feedforward_network.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-try:
- from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
- from torch.nn import LayerNorm
-
-
-from .xmoe.global_groups import get_moe_group
-
-
-class set_torch_seed(object):
- def __init__(self, seed):
- assert isinstance(seed, int)
- self.rng_state = self.get_rng_state()
-
- torch.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
-
- def get_rng_state(self):
- state = {"torch_rng_state": torch.get_rng_state()}
- if torch.cuda.is_available():
- state["cuda_rng_state"] = torch.cuda.get_rng_state()
- return state
-
- def set_rng_state(self, state):
- torch.set_rng_state(state["torch_rng_state"])
- if torch.cuda.is_available():
- torch.cuda.set_rng_state(state["cuda_rng_state"])
-
- def __enter__(self):
- return self
-
- def __exit__(self, *exc):
- self.set_rng_state(self.rng_state)
-
-
-def make_experts(args, embed_dim, expert_ffn_dim):
- world_size = (
- 1
- if not torch.distributed.is_initialized()
- else torch.distributed.get_world_size()
- )
- expert_list = []
- ddp_rank = args.ddp_rank
- start_seed = torch.randint(1000000, (1,)).item()
- # at least as many experts than gpus
- if args.moe_expert_count >= world_size:
- assert (
- args.moe_expert_count % world_size == 0
- ), f"{args.moe_expert_count}, {world_size}"
- local_moe_expert_count = args.moe_expert_count // world_size
- for i in range(local_moe_expert_count):
- with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
- expert_list.append(
- FeedForwardNetwork(
- embed_dim,
- expert_ffn_dim,
- args.activation_fn,
- args.dropout,
- args.activation_dropout,
- args.layernorm_eps,
- args.subln,
- )
- )
- else:
- assert (
- world_size % args.moe_expert_count == 0
- ), f"{world_size}, {args.moe_expert_count}"
-
- moe_idx, _ = get_moe_group(args.moe_expert_count)
-
- with set_torch_seed(start_seed + moe_idx):
- expert_list.append(
- FeedForwardNetwork(
- embed_dim,
- expert_ffn_dim,
- args.activation_fn,
- args.dropout,
- args.activation_dropout,
- args.layernorm_eps,
- args.subln,
- )
- )
- experts = nn.ModuleList(expert_list)
- return experts
-
-
-def get_activation_fn(activation):
- if activation == "relu":
- return F.relu
- elif activation == "gelu":
- return F.gelu
- elif activation == "swish":
- return F.silu
- else:
- raise NotImplementedError
-
-
-class FeedForwardNetwork(nn.Module):
- def __init__(
- self,
- embed_dim,
- ffn_dim,
- activation_fn,
- dropout,
- activation_dropout,
- layernorm_eps,
- subln=False,
- ):
- super().__init__()
- self.embed_dim = embed_dim
- self.activation_fn = get_activation_fn(activation=str(activation_fn))
- self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
- self.dropout_module = torch.nn.Dropout(dropout)
- self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
- self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
- self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
-
- def reset_parameters(self):
- self.fc1.reset_parameters()
- self.fc2.reset_parameters()
- if self.ffn_layernorm is not None:
- self.ffn_layernorm.reset_parameters()
-
- def forward(self, x):
- x_shape = x.shape
- x = x.reshape(-1, x.size(-1))
- x = self.fc1(x)
- x = self.activation_fn(x.float()).type_as(x)
- x = self.activation_dropout_module(x)
- if self.ffn_layernorm is not None:
- x = self.ffn_layernorm(x)
- x = self.fc2(x)
- x = x.view(x_shape)
- x = self.dropout_module(x)
- return x
diff --git a/code/xtuner/model/torchscale/component/flash_attention.py b/code/xtuner/model/torchscale/component/flash_attention.py
deleted file mode 100644
index 8c76a2d0469ff3c08fe144cc221b2cc7fdcdf1ae..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/flash_attention.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# Copyright (c) 2023 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-
-from typing import Any, Optional
-import torch
-
-# if torch.cuda.is_available():
-# try:
-# if torch.cuda.get_device_capability()[0] > 7:
-# from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
-
-# def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
-# assert bias is None
-# attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
-# return attn, lse
-
-# else:
-# from xformers.ops.fmha import (
-# cutlass,
-# Inputs,
-# Context,
-# _memory_efficient_attention_forward_requires_grad,
-# _memory_efficient_attention_backward,
-# LowerTriangularMask,
-# )
-
-# class FlashAttnFunc(torch.autograd.Function):
-# @staticmethod
-# # type: ignore
-# def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
-# if is_causal:
-# assert bias is None
-# attn_bias = LowerTriangularMask()
-# else:
-# attn_bias = bias
-
-# inp = Inputs(
-# query=q,
-# key=k,
-# value=v,
-# attn_bias=attn_bias,
-# p=dropout,
-# scale=softmax_scale,
-# )
-# op_fw = cutlass.FwOp
-# op_bw = cutlass.BwOp
-
-# out, op_ctx = _memory_efficient_attention_forward_requires_grad(
-# inp=inp, op=op_fw
-# )
-
-# # Saving attn_bias is a bit complicated, as the
-# # torch part should go in `save_for_backward`
-# if isinstance(inp.attn_bias, torch.Tensor):
-# attn_bias_tensor = inp.attn_bias
-# attn_bias_ctx = None
-# else:
-# attn_bias_tensor = None
-# attn_bias_ctx = inp.attn_bias
-
-# ctx.save_for_backward(
-# inp.query,
-# inp.key,
-# inp.value,
-# op_ctx.out,
-# op_ctx.lse,
-# )
-# ctx.rng_state = op_ctx.rng_state
-# ctx.attn_bias_tensor = attn_bias_tensor
-# if op_ctx.op_bw is not None:
-# if op_bw is not None and op_bw is not op_ctx.op_bw:
-# raise ValueError(
-# f"Specified op_bw={op_bw.NAME}, but forward op "
-# f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
-# )
-# op_bw = op_ctx.op_bw
-# ctx.op_fw = op_fw
-# ctx.op_bw = op_bw
-# ctx.p = inp.p
-
-# ctx.scale = inp.scale
-# ctx.attn_bias_ctx = attn_bias_ctx
-# return out, op_ctx.lse
-
-# @staticmethod
-# def deserialize_bias(
-# attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
-# ) -> Any:
-# if attn_bias_tensor is None:
-# return attn_bias_ctx
-# return attn_bias_tensor
-
-# @classmethod
-# @torch.autograd.function.once_differentiable
-# def backward(cls, ctx, grad, dlse):
-# # Re-create context
-# query, key, value, out, lse = ctx.saved_tensors
-# attn_bias_tensor = ctx.attn_bias_tensor
-# rng_state = ctx.rng_state
-# inp = Inputs(
-# query=query,
-# key=key,
-# value=value,
-# attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
-# p=ctx.p,
-# scale=ctx.scale,
-# )
-# op_ctx = Context(
-# lse=lse,
-# out=out,
-# rng_state=rng_state,
-# )
-# grads = _memory_efficient_attention_backward(
-# ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
-# )
-# return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
-
-# flash_attn_func = FlashAttnFunc.apply
-# except ModuleNotFoundError:
-# flash_attn_func = None
-# else:
-# flash_attn_func = None
-
-
-if torch.cuda.is_available():
- try:
- if torch.cuda.get_device_capability()[0] > 7:
- from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
-
- def flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
- assert bias is None
- attn, lse, _ = _flash_attn_func(q, k, v, dropout_p=dropout, softmax_scale=softmax_scale, causal=is_causal, return_attn_probs=True)
- return attn, lse
-
- else:
- from xformers.ops.fmha import (
- cutlass,
- Inputs,
- Context,
- _memory_efficient_attention_forward_requires_grad,
- _memory_efficient_attention_backward,
- LowerTriangularMask,
- )
-
- class FlashAttnFunc(torch.autograd.Function):
- @staticmethod
- # type: ignore
- def forward(ctx, q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False):
- if is_causal:
- assert bias is None
- attn_bias = LowerTriangularMask()
- else:
- attn_bias = bias
-
- inp = Inputs(
- query=q,
- key=k,
- value=v,
- attn_bias=attn_bias,
- p=dropout,
- scale=softmax_scale,
- )
- op_fw = cutlass.FwOp
- op_bw = cutlass.BwOp
-
- out, op_ctx = _memory_efficient_attention_forward_requires_grad(
- inp=inp, op=op_fw
- )
-
- # Saving attn_bias is a bit complicated, as the
- # torch part should go in `save_for_backward`
- if isinstance(inp.attn_bias, torch.Tensor):
- attn_bias_tensor = inp.attn_bias
- attn_bias_ctx = None
- else:
- attn_bias_tensor = None
- attn_bias_ctx = inp.attn_bias
-
- ctx.save_for_backward(
- inp.query,
- inp.key,
- inp.value,
- op_ctx.out,
- op_ctx.lse,
- )
- ctx.rng_state = op_ctx.rng_state
- ctx.attn_bias_tensor = attn_bias_tensor
- if op_ctx.op_bw is not None:
- if op_bw is not None and op_bw is not op_ctx.op_bw:
- raise ValueError(
- f"Specified op_bw={op_bw.NAME}, but forward op "
- f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
- )
- op_bw = op_ctx.op_bw
- ctx.op_fw = op_fw
- ctx.op_bw = op_bw
- ctx.p = inp.p
-
- ctx.scale = inp.scale
- ctx.attn_bias_ctx = attn_bias_ctx
- return out, op_ctx.lse
-
- @staticmethod
- def deserialize_bias(
- attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
- ) -> Any:
- if attn_bias_tensor is None:
- return attn_bias_ctx
- return attn_bias_tensor
-
- @classmethod
- @torch.autograd.function.once_differentiable
- def backward(cls, ctx, grad, dlse):
- # Re-create context
- query, key, value, out, lse = ctx.saved_tensors
- attn_bias_tensor = ctx.attn_bias_tensor
- rng_state = ctx.rng_state
- inp = Inputs(
- query=query,
- key=key,
- value=value,
- attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
- p=ctx.p,
- scale=ctx.scale,
- )
- op_ctx = Context(
- lse=lse,
- out=out,
- rng_state=rng_state,
- )
- grads = _memory_efficient_attention_backward(
- ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
- )
- return grads.dq, grads.dk, grads.dv, None, grads.db, None, None
-
- flash_attn_func = FlashAttnFunc.apply
- except ModuleNotFoundError:
- flash_attn_func = None
-else:
- flash_attn_func = None
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/component/gate_linear_unit.py b/code/xtuner/model/torchscale/component/gate_linear_unit.py
deleted file mode 100644
index ecc9b34f194acd6f14b86a7d541599124fbf4b9d..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/gate_linear_unit.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .feedforward_network import get_activation_fn
-
-
-class GLU(nn.Module):
- def __init__(
- self,
- embed_dim,
- ffn_dim,
- activation_fn,
- dropout,
- activation_dropout,
- ):
- super().__init__()
- self.embed_dim = embed_dim
- self.activation_fn = get_activation_fn(activation=str(activation_fn))
- self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
- self.dropout_module = torch.nn.Dropout(dropout)
- self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
- self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
- self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
-
- def reset_parameters(self):
- self.fc1.reset_parameters()
- self.fc2.reset_parameters()
- self.gate.reset_parameters()
-
- def forward(self, x):
- x_shape = x.shape
- x = x.reshape(-1, x.size(-1))
- g = self.gate(x)
- x = self.fc1(x)
- x = self.activation_fn(x.float()).type_as(x) * g
- x = self.activation_dropout_module(x)
- x = self.fc2(x)
- x = x.view(x_shape)
- x = self.dropout_module(x)
- return x
diff --git a/code/xtuner/model/torchscale/component/multihead_attention.py b/code/xtuner/model/torchscale/component/multihead_attention.py
deleted file mode 100644
index 33c044cdeb547c741bf9daf927b2943b6a3ea8b5..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/multihead_attention.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from einops import rearrange
-try:
- from apex.normalization import FusedLayerNorm as LayerNorm
-except ModuleNotFoundError:
- from torch.nn import LayerNorm
-
-from .multiway_network import MultiwayWrapper
-from .xpos_relative_position import XPOS
-from .flash_attention import flash_attn_func
-
-
-class MultiheadAttention(nn.Module):
- def __init__(
- self,
- args,
- embed_dim,
- num_heads,
- dropout=0.0,
- self_attention=False,
- encoder_decoder_attention=False,
- subln=False,
- ):
- super().__init__()
- self.args = args
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.head_dim = embed_dim // num_heads
- self.scaling = self.head_dim**-0.5
- self.dropout = dropout
-
- self.self_attention = self_attention
- self.encoder_decoder_attention = encoder_decoder_attention
- assert self.self_attention ^ self.encoder_decoder_attention
-
- self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
- self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
- self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
- self.out_proj = MultiwayWrapper(
- args, nn.Linear(embed_dim, embed_dim, bias=True)
- )
- self.inner_attn_ln = (
- MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
- if subln and self.self_attention
- else None
- )
- self.dropout_module = torch.nn.Dropout(dropout)
- self.xpos = (
- XPOS(self.head_dim, args.xpos_scale_base)
- if args.xpos_rel_pos and self.self_attention
- else None
- )
-
- def reset_parameters(self):
- nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
- nn.init.xavier_uniform_(self.out_proj.weight)
- nn.init.constant_(self.out_proj.bias, 0.0)
-
- def attention_ops(self, q, k, v, key_padding_mask=None, attn_mask=None, rel_pos=None, is_causal=False):
- if not self.args.flash_attention:
- q *= self.scaling
- attn_weights = torch.bmm(q, k.transpose(1, 2))
-
- if attn_mask is not None:
- attn_weights = torch.nan_to_num(attn_weights)
- attn_mask = attn_mask.unsqueeze(0)
- attn_weights += attn_mask
-
- if key_padding_mask is not None:
- attn_weights = rearrange(attn_weights, '(b h) t s -> b h t s', h=self.num_heads)
- attn_weights = attn_weights.masked_fill(
- key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
- float("-inf"),
- )
- attn_weights = rearrange(attn_weights, 'b h t s -> (b h) t s')
-
- if rel_pos is not None:
- rel_pos = rel_pos.view(attn_weights.size())
- attn_weights = attn_weights + rel_pos
-
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
- attn_weights
- )
- attn_probs = self.dropout_module(attn_weights)
-
- attn = torch.bmm(attn_probs, v)
- attn = rearrange(attn, '(b h) l d -> b l (h d)', h=self.num_heads)
- else:
- assert flash_attn_func is not None
- assert rel_pos is None
- q = rearrange(q, '(b h) l d -> b l h d', h=self.num_heads)
- k = rearrange(k, '(b h) l d -> b l h d', h=self.num_heads)
- v = rearrange(v, '(b h) l d -> b l h d', h=self.num_heads)
- # flash_attn_func(q, k, v, dropout=0.0, bias=None, softmax_scale=None, is_causal=False)
- attn, lse = flash_attn_func(q, k, v, self.dropout, attn_mask, None, is_causal)
- # attn, lse = flash_attn_func(q, k, v, dropout=self.dropout, is_causal=is_causal)
- attn = rearrange(attn, 'b l h d -> b l (h d)')
- attn_weights = lse[:, :, :attn.size(1)]
-
- return attn, attn_weights
-
- def forward(
- self,
- query,
- key,
- value,
- incremental_state=None,
- key_padding_mask=None,
- attn_mask=None,
- rel_pos=None,
- is_first_step=False,
- is_causal=False,
- ):
- bsz, tgt_len, embed_dim = query.size()
- src_len = tgt_len
- assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
-
- key_bsz, src_len, _ = key.size()
- assert key_bsz == bsz, f"{query.size(), key.size()}"
- assert value is not None
- assert bsz, src_len == value.shape[:2]
-
- q = self.q_proj(query)
- k = self.k_proj(key)
- v = self.v_proj(value)
-
- q = rearrange(q, 'b l (h d) -> (b h) l d', h=self.num_heads)
- k = rearrange(k, 'b l (h d) -> (b h) l d', h=self.num_heads)
- v = rearrange(v, 'b l (h d) -> (b h) l d', h=self.num_heads)
-
- if incremental_state is not None:
- if "prev_key" in incremental_state:
- prev_key = incremental_state["prev_key"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- prev_value = incremental_state["prev_value"].view(
- bsz * self.num_heads, -1, self.head_dim
- )
- k = torch.cat([prev_key, k], dim=1)
- v = torch.cat([prev_value, v], dim=1)
- incremental_state["prev_key"] = k.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- incremental_state["prev_value"] = v.view(
- bsz, self.num_heads, -1, self.head_dim
- )
- src_len = k.size(1)
-
- if self.xpos is not None:
- if incremental_state is not None and not is_first_step:
- offset = src_len - 1
- else:
- offset = 0
- k = self.xpos(k, offset=0, downscale=True)
- q = self.xpos(q, offset=offset, downscale=False)
-
- attn, attn_weights = self.attention_ops(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rel_pos=rel_pos, is_causal=is_causal)
-
- if self.inner_attn_ln is not None:
- attn = self.inner_attn_ln(attn)
-
- attn = self.out_proj(attn)
-
- return attn, attn_weights
diff --git a/code/xtuner/model/torchscale/component/multiscale_retention.py b/code/xtuner/model/torchscale/component/multiscale_retention.py
deleted file mode 100644
index d475ccf79c0691dd55ab0a2c5a6d05010c572c21..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/multiscale_retention.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from .rms_norm import RMSNorm
-
-from .multiway_network import MultiwayWrapper
-
-def rotate_every_two(x):
- x1 = x[:, :, :, ::2]
- x2 = x[:, :, :, 1::2]
- x = torch.stack((-x2, x1), dim=-1)
- return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
-
-def duplicate_interleave(m):
- """
- A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
- """
- dim0 = m.shape[0]
- m = m.view(-1, 1) # flatten the matrix
- m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
- m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
- return m
-
-def theta_shift(x, sin, cos):
- return (x * cos) + (rotate_every_two(x) * sin)
-
-def get_activation_fn(activation):
- if activation == "swish":
- return F.silu
- elif activation == "gelu":
- return F.gelu
- else:
- raise NotImplementedError
-
-class MultiScaleRetention(nn.Module):
- def __init__(
- self,
- args,
- embed_dim,
- value_dim,
- num_heads,
- gate_fn="swish",
- ):
- super().__init__()
- self.args = args
- self.embed_dim = embed_dim
- self.value_dim = value_dim
- self.num_heads = num_heads
- self.head_dim = self.value_dim // num_heads
- self.key_dim = self.embed_dim // num_heads
- self.scaling = self.key_dim ** -0.5
-
- self.gate_fn = get_activation_fn(activation=str(gate_fn))
-
- self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
- self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
- self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
- self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
-
- self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
-
- self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
- self.reset_parameters()
-
- def reset_parameters(self):
- nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
- nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
- nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
- nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
- nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -1)
-
- def parallel_forward(self, qr, kr, v, mask):
- bsz, tgt_len, embed_dim = v.size()
-
- vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
-
- qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
- qk_mat = qk_mat * mask
- # invariant after normalization
- qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
- output = torch.matmul(qk_mat, vr)
- output = output.transpose(1, 2)
- return output
-
- def recurrent_forward(
- self,
- qr, kr, v,
- decay,
- incremental_state
- ):
- bsz = v.size(0)
-
- v = v.view(bsz, self.num_heads, self.head_dim, 1)
- kv = kr * v
- if "prev_key_value" in incremental_state:
- prev_kv = incremental_state["prev_key_value"]
- prev_scale = incremental_state["scale"]
- scale = prev_scale * decay + 1
- kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
- # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
- else:
- scale = torch.ones_like(decay)
-
- incremental_state["prev_key_value"] = kv
- incremental_state["scale"] = scale
-
- output = torch.sum(qr * kv, dim=3)
- return output
-
- def chunk_recurrent_forward(
- self,
- qr, kr, v,
- inner_mask
- ):
- mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
- bsz, tgt_len, embed_dim = v.size()
- chunk_len = mask.size(1)
- num_chunks = tgt_len // chunk_len
-
- assert tgt_len % chunk_len == 0
-
- qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
- kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
- v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)
-
- kr_t = kr.transpose(-1, -2)
-
- qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
- qk_mat = qk_mat * mask
- inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
- qk_mat = qk_mat / inner_scale
- inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
-
- # reduce kv in one chunk
- kv = kr_t @ (v * value_inner_decay)
-
- kv_recurrent = []
- cross_scale = []
- kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
- kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
-
- # accumulate kv by loop
- for i in range(num_chunks):
- kv_recurrent.append(kv_state / kv_scale)
- cross_scale.append(kv_scale)
- kv_state = kv_state * cross_decay + kv[:, i]
- kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
-
- kv_recurrent = torch.stack(kv_recurrent, dim=1)
- cross_scale = torch.stack(cross_scale, dim=1)
-
- all_scale = torch.maximum(inner_scale, cross_scale)
- align_inner_scale = all_scale / inner_scale
- align_cross_scale = all_scale / cross_scale
-
- cross_output = (qr * query_inner_decay) @ kv_recurrent
- output = inner_output / align_inner_scale + cross_output / align_cross_scale
- # output = inner_output / cross_scale + cross_output / inner_scale
-
- output = output.transpose(2, 3)
- return output
-
- def forward(
- self,
- x,
- rel_pos,
- chunkwise_recurrent=False,
- incremental_state=None
- ):
- bsz, tgt_len, _ = x.size()
- (sin, cos), inner_mask = rel_pos
-
- q = self.q_proj(x)
- k = self.k_proj(x)
- v = self.v_proj(x)
- g = self.g_proj(x)
-
- k *= self.scaling
- q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
- k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
-
- qr = theta_shift(q, sin, cos)
- kr = theta_shift(k, sin, cos)
-
- if incremental_state is not None:
- output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
- elif chunkwise_recurrent:
- output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
- else:
- output = self.parallel_forward(qr, kr, v, inner_mask)
-
- output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
-
- output = self.gate_fn(g) * output
-
- output = self.out_proj(output)
-
- return output
-
-
diff --git a/code/xtuner/model/torchscale/component/multiway_network.py b/code/xtuner/model/torchscale/component/multiway_network.py
deleted file mode 100644
index a44a699e9e3b4d2aa9389d5b1e572158b792045d..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/multiway_network.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import copy
-
-import torch
-import torch.nn as nn
-
-
-def MultiwayWrapper(args, module, dim=1):
- if args.multiway:
- return MultiwayNetwork(module, dim=dim)
- return module
-
-
-def set_split_position(position):
- def apply_fn(module):
- if hasattr(module, "split_position"):
- module.split_position = position
-
- return apply_fn
-
-
-class MultiwayNetwork(nn.Module):
- def __init__(self, module, dim=1):
- super().__init__()
- self.dim = dim
- self.A = module
- self.B = copy.deepcopy(module)
- self.B.reset_parameters()
- self.split_position = -1
-
- def forward(self, x, **kwargs):
- if self.split_position == -1:
- return self.A(x, **kwargs)
- if self.split_position == 0:
- return self.B(x, **kwargs)
- x1, x2 = torch.split(
- x,
- [self.split_position, x.size(self.dim) - self.split_position],
- dim=self.dim,
- )
- # x1, x2 = x[:self.split_position], x[self.split_position:]
- y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
- return torch.cat([y1, y2], dim=self.dim)
-
-
-class MutliwayEmbedding(MultiwayNetwork):
- def __init__(self, modules, dim=1):
- super(MultiwayNetwork, self).__init__()
- self.dim = dim
- assert len(modules) == 2
- self.A = modules[0]
- self.B = modules[1]
- self.split_position = -1
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/component/relative_position_bias.py b/code/xtuner/model/torchscale/component/relative_position_bias.py
deleted file mode 100644
index e9686f0e2e3d107b25b2e04c0df5389c91ca10bb..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/relative_position_bias.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import math
-
-import torch
-import torch.nn as nn
-
-
-class RelativePositionBias(nn.Module):
- def __init__(
- self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
- ):
- super().__init__()
- self.bidirectional = bidirectional
- self.num_buckets = num_buckets
- self.max_distance = max_distance
- self.n_heads = n_heads
- self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
-
- @staticmethod
- def _relative_position_bucket(
- relative_position, bidirectional=True, num_buckets=32, max_distance=128
- ):
- ret = 0
- n = -relative_position
- if bidirectional:
- num_buckets //= 2
- ret += (n < 0).to(torch.long) * num_buckets
- n = torch.abs(n)
- else:
- n = torch.max(n, torch.zeros_like(n))
-
- max_exact = num_buckets // 2
- is_small = n < max_exact
-
- val_if_large = max_exact + (
- torch.log(n.float() / max_exact)
- / math.log(max_distance / max_exact)
- * (num_buckets - max_exact)
- ).to(torch.long)
- val_if_large = torch.min(
- val_if_large, torch.full_like(val_if_large, num_buckets - 1)
- )
-
- ret += torch.where(is_small, n, val_if_large)
- return ret
-
- def compute_bias(self, qlen, klen, step=None):
- step = 0 if step is None else step
- context_position = torch.arange(
- step,
- step + qlen,
- dtype=torch.long,
- device=self.relative_attention_bias.weight.device,
- )[:, None]
- memory_position = torch.arange(
- klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
- )[None, :]
- relative_position = memory_position - context_position # shape (qlen, klen)
-
- rp_bucket = self._relative_position_bucket(
- relative_position, # shape (qlen, klen)
- bidirectional=self.bidirectional,
- num_buckets=self.num_buckets,
- max_distance=self.max_distance,
- )
- rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
- values = self.relative_attention_bias(
- rp_bucket
- ) # shape (qlen, klen, num_heads)
- values = values.permute([2, 0, 1]).unsqueeze(
- 0
- ) # shape (1, num_heads, qlen, klen)
- return values
-
- def forward(self, batch_size, qlen, klen, step=None):
- # shape (batch * num_heads, qlen, klen)
- return (
- self.compute_bias(qlen, klen, step)
- .repeat(batch_size, 1, 1, 1)
- .view(-1, qlen, klen)
- )
diff --git a/code/xtuner/model/torchscale/component/rms_norm.py b/code/xtuner/model/torchscale/component/rms_norm.py
deleted file mode 100644
index 465536cbb5a98b91d876b25aae44eeb659d8e84c..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/rms_norm.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch
-import torch.nn as nn
-
-class RMSNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
- super().__init__()
- self.eps = eps
- self.elementwise_affine = elementwise_affine
- if self.elementwise_affine:
- self.weight = nn.Parameter(torch.ones(dim))
- else:
- self.register_parameter('weight', None)
-
- def _norm(self, x):
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
- def forward(self, x):
- output = self._norm(x.float()).type_as(x)
- if self.weight is not None:
- output = output * self.weight
- return output
-
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/component/utils.py b/code/xtuner/model/torchscale/component/utils.py
deleted file mode 100644
index 4c8a5ad4441623f96a6fcc93a211c132a8f00536..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/utils.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright (c) 2023 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch
-import torch.distributed as dist
-
-def padding_to_multiple_of(n, mult):
- remainder = n % mult
- if remainder == 0:
- return 0
- return mult - remainder
-
-def get_data_parallel_group():
- if torch.distributed.is_initialized():
- if not hasattr(get_data_parallel_group, "_global_group"):
- get_data_parallel_group._global_group = dist.new_group()
- return get_data_parallel_group._global_group
- else:
- return None
-
-def get_rank(group):
- return dist.get_rank(group=group)
-
-def get_world_size(group):
- if torch.distributed.is_initialized():
- return dist.get_world_size(group=group)
- else:
- return 1
-
-def get_data_parallel_rank():
- return get_rank(get_data_parallel_group())
-
-def get_data_parallel_world_size():
- return get_world_size(get_data_parallel_group())
-
-
-class Allgather(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, input_):
- world_size = get_data_parallel_world_size()
- dim_size = list(input_.size())
- dim_size[0] = dim_size[0] * world_size
-
- output = torch.empty(dim_size, dtype=input_.dtype,
- device=torch.cuda.current_device())
- torch.distributed._all_gather_base(output, input_.contiguous(),
- group=get_data_parallel_group())
-
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- world_size = get_data_parallel_world_size()
-
- dim_size = list(grad_output.size())
- assert dim_size[0] % world_size == 0, \
- "First dimension of the tensor should be divisible by tensor parallel size"
-
- dim_size[0] = dim_size[0] // world_size
-
- output = torch.empty(dim_size, dtype=grad_output.dtype,
- device=torch.cuda.current_device())
-
- torch.distributed._reduce_scatter_base(output, grad_output.contiguous(),
- group=get_data_parallel_group())
-
- return output
-
-all_gather_func = Allgather.apply
diff --git a/code/xtuner/model/torchscale/component/xmoe/__init__.py b/code/xtuner/model/torchscale/component/xmoe/__init__.py
deleted file mode 100644
index 3ae31e2507e8759f2ac7f85e517288f536c04ac3..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/xmoe/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
diff --git a/code/xtuner/model/torchscale/component/xmoe/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/torchscale/component/xmoe/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 1c59ba2af8d6cc840e3764f38a4fd7ce760804d1..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/xmoe/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/xmoe/__pycache__/global_groups.cpython-311.pyc b/code/xtuner/model/torchscale/component/xmoe/__pycache__/global_groups.cpython-311.pyc
deleted file mode 100644
index d0a7e2a5ca94c0afe44d6ec5fb6b842e4f348e6b..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/xmoe/__pycache__/global_groups.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/xmoe/__pycache__/moe_layer.cpython-311.pyc b/code/xtuner/model/torchscale/component/xmoe/__pycache__/moe_layer.cpython-311.pyc
deleted file mode 100644
index 4e7bf3927d7ce0670dc72c779118dbdd4590a417..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/xmoe/__pycache__/moe_layer.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/xmoe/__pycache__/routing.cpython-311.pyc b/code/xtuner/model/torchscale/component/xmoe/__pycache__/routing.cpython-311.pyc
deleted file mode 100644
index 9e80b3d9eecd65305005732ec60edc80ffb41977..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/component/xmoe/__pycache__/routing.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/component/xmoe/global_groups.py b/code/xtuner/model/torchscale/component/xmoe/global_groups.py
deleted file mode 100644
index c6c31096b9eee3c48025f63e21419a28ceb759c1..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/xmoe/global_groups.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import torch.distributed as dist
-
-
-def _find_my_group_index(grouped_ranks):
- my_rank = dist.get_rank()
- for i, group in enumerate(grouped_ranks):
- if my_rank in group:
- return i
- raise RuntimeError
-
-def get_moe_group(moe_expert_count=None):
- if dist.is_initialized():
- if not hasattr(get_moe_group, "_moe_groups"):
- world_size = dist.get_world_size()
-
- if world_size <= moe_expert_count:
- assert moe_expert_count % world_size == 0
- moe_groups = [[i] for i in range(world_size)]
-
- else:
- assert world_size % moe_expert_count == 0
- ranks_per_group = world_size // moe_expert_count
- moe_groups = [
- [i + j * moe_expert_count for j in range(ranks_per_group)]
- for i in range(moe_expert_count)
- ]
-
- get_moe_group._moe_expert_count = moe_expert_count
- get_moe_group._moe_group_idx = moe_groups
- get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
-
- my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
- return my_group_idx, get_moe_group._moe_groups[my_group_idx]
-
-
-def get_all2all_group(moe_expert_count):
- if dist.is_initialized():
- if not hasattr(get_all2all_group, "_all2all_groups"):
- world_size = dist.get_world_size()
-
- # more experts than world size
- if world_size <= moe_expert_count:
- assert moe_expert_count % world_size == 0
- all2all_groups = [[i for i in range(world_size)]]
-
- # larger world than num experts
- else:
- assert world_size % moe_expert_count == 0
- ranks_per_group = world_size // moe_expert_count
- all2all_groups = [
- [i * moe_expert_count + j for j in range(moe_expert_count)]
- for i in range(ranks_per_group)
- ]
-
- get_all2all_group._all2all_group_idx = all2all_groups
- get_all2all_group._all2all_groups = [
- dist.new_group(g) for g in all2all_groups
- ]
-
- my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
- return get_all2all_group._all2all_groups[my_group_idx]
diff --git a/code/xtuner/model/torchscale/component/xmoe/moe_layer.py b/code/xtuner/model/torchscale/component/xmoe/moe_layer.py
deleted file mode 100644
index 51e77137e2fc56ffb1e3a37494857542b079f3f7..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/xmoe/moe_layer.py
+++ /dev/null
@@ -1,307 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-# NOTE: This is a mirror of the code in
-# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
-
-import logging
-import time
-from typing import Any, Tuple, cast
-
-import torch
-import torch.distributed as dist
-from torch import Tensor
-from torch.nn import Module, ModuleList
-
-from .global_groups import get_all2all_group, get_moe_group
-
-try:
- from fairseq.modules.moe import MOELayer
-
- has_fairseq = True
- Base = MOELayer
-except ModuleNotFoundError:
- Base = Module
- has_fairseq = False
-
-try:
- # To enable Tutel MoE optimizations:
- # python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
- from tutel import moe as tutel_moe
-
- has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
-except ModuleNotFoundError:
- has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1
-
-logger = logging.getLogger(__name__)
-
-
-# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
-# See https://arxiv.org/pdf/2006.16668.pdf for details.
-
-# Based on https://github.com/pytorch/pytorch/pull/40762
-class _AllToAll(torch.autograd.Function):
- @staticmethod
- def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
- ctx.group = group
- input = input.contiguous()
- output = torch.empty_like(input)
- if torch.distributed.is_initialized():
- dist.all_to_all_single(output, input, group=group)
- else:
- assert group is None
- output = input
- return output
-
- @staticmethod
- def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
- return (None, _AllToAll.apply(ctx.group, *grad_output))
-
-
-
-
-class MOELayer(Base):
- """MOELayer module which implements MixtureOfExperts as described in Gshard_.
- ::
-
- gate = Top2Gate(model_dim, num_experts)
- moe = MOELayer(gate, expert)
- output = moe(input)
- l_aux = moe.l_aux
-
- .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
-
- Args:
- gate (torch.nn.Module):
- gate network
- expert (torch.nn.Module):
- expert network
- """
-
- def __init__(self, gate, experts, args):
- if has_fairseq:
- super(Base, self).__init__()
- else:
- super().__init__()
- self.gate = gate
- if type(experts) == ModuleList:
- self.experts = cast(ModuleList, experts)
- else:
- self.experts = ModuleList([experts])
- _, self.expert_group = get_moe_group(args.moe_expert_count)
- self.all2all_group = get_all2all_group(args.moe_expert_count)
- self.world_size = dist.get_world_size(group=self.expert_group)
- self.all2all_size = dist.get_world_size(group=self.all2all_group)
- for p in experts.parameters():
- p.expert = True # type: ignore
- self.num_local_experts = len(self.experts)
- self.args = args
- self.in_generation = False
- self.a2a_cuda_event_intervals = []
- self.a2a_cpu_time_ms = 0.0
-
- def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
- assert len(input) == 1, "only single input Tensor supported"
- input = input[0]
- assert (
- len(input.shape) == 3
- ), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
- if input_padding_mask is not None:
- assert (
- len(input_padding_mask.shape) == 2
- ), "input Tensor must have dimensions: (s)equence, (t)oken"
- assert input_padding_mask.shape[0] == input.shape[0]
- assert input_padding_mask.shape[1] == input.shape[1]
- # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
-
- # Implement Algorithm 2 from GShard paper.
- d_model = input.shape[2]
- # Pad to expected batch size
- input_shape = list(input.shape)
- expected_bsz = (
- getattr(self.args, "batch_size", 0)
- if self.training
- else getattr(self.args, "batch_size_valid", 0)
- )
- # This indicates that --batch-size or --max-sentences is not specified
- if expected_bsz is None:
- expected_bsz = 0
- # Note: Padding is not necessary at generation time at present
- # because all DDP workers process the same batch. Also, batch size at generation time
- # can be different from that present in the checkpoint state
- if (
- not self.in_generation
- and expected_bsz != 0
- and input_shape[0] != expected_bsz
- ):
- logger.warning(
- f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})"
- )
- assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
- padded_input = torch.zeros(
- (expected_bsz, input_shape[1], input_shape[2]),
- dtype=input.dtype,
- layout=input.layout,
- device=input.device,
- )
- padded_input[: input_shape[0], :, :] = input
- input = padded_input
-
- padded_input_padding_mask = torch.ones(
- (
- expected_bsz,
- input_shape[1],
- ),
- dtype=torch.bool,
- device=input.device,
- )
- if input_padding_mask is not None:
- padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
- else:
- padded_input_padding_mask[: input_shape[0], :] = False
- input_padding_mask = padded_input_padding_mask
-
- # Reshape into S tokens by dropping sequence dimension.
- reshaped_input = input.reshape(-1, d_model)
- reshaped_input_shape = reshaped_input.shape
- reshaped_input_padding_mask = (
- input_padding_mask.reshape(-1) if input_padding_mask is not None else None
- )
-
- # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
- # Pro of --max-tokens: more flexible for MT variable sequence lengths
- # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
- if expected_bsz == 0:
- expected_dim = reshaped_input_shape[0] * torch.ones(
- (1,), dtype=torch.long, device=input.device
- )
- dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
- expected_dim = int(expected_dim.item())
- padded_input = torch.zeros(
- (expected_dim, reshaped_input_shape[1]),
- dtype=input.dtype,
- layout=input.layout,
- device=input.device,
- )
- padded_input[: reshaped_input_shape[0], :] = reshaped_input
- reshaped_input = padded_input
-
- padded_input_padding_mask = torch.ones(
- (expected_dim,), dtype=torch.bool, device=padded_input.device
- )
- if reshaped_input_padding_mask is not None:
- padded_input_padding_mask[
- : reshaped_input_shape[0]
- ] = reshaped_input_padding_mask
- else:
- padded_input_padding_mask[: reshaped_input_shape[0]] = False
- reshaped_input_padding_mask = padded_input_padding_mask
-
- if has_tutel:
- l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
- reshaped_input, reshaped_input_padding_mask
- )
- S, M = reshaped_input.size(0), reshaped_input.size(1)
-
- if not hasattr(self, "_tutel_dispatcher"):
- self._tutel_dispatcher = tutel_moe.fast_dispatcher(
- E, C, M, dispatch_dtype=reshaped_input.dtype
- )
- self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
- dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
- else:
- l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
- reshaped_input, reshaped_input_padding_mask
- )
-
- dispatch_mask = dispatch_mask.to(input.dtype).permute(
- 1, 2, 0
- ) # S,E,C -> E,C,S
- E, C, S = dispatch_mask.size()
- M = reshaped_input.size(1)
- assert reshaped_input.size() == (S, M)
- # einsum("sec,sm->ecm")
- dispatched_input = torch.mm(
- dispatch_mask.view(E * C, S), reshaped_input
- ) # -> (E*C),M
-
- if self.all2all_size > 1:
- dispatched_input = self.all_to_all_wrapper(dispatched_input)
-
- # Re-shape after all-to-all: ecm -> gecm
- dispatched_input = dispatched_input.reshape(
- self.all2all_size, self.num_local_experts, -1, d_model
- )
- chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
- expert_outputs = []
- for chunk, expert in zip(chunks, self.experts):
- expert_outputs += [expert(chunk)]
- expert_output = torch.cat(expert_outputs, dim=1)
-
- if self.all2all_size > 1:
- expert_output = self.all_to_all_wrapper(expert_output)
-
- # Re-shape back: gecm -> ecm
- expert_output = expert_output.reshape(
- self.all2all_size * self.num_local_experts, -1, d_model
- )
-
- if has_tutel:
- combined_output = self._tutel_dispatcher.decode(
- expert_output.view(E * C, M)
- )
- else:
- # einsum("sec,ecm->sm")
- combined_output = combine_weights.view(S, E * C).mm(
- expert_output.view(E * C, M)
- )
-
- # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
- combined_output = combined_output[: reshaped_input_shape[0], :]
- combined_output = combined_output.reshape(input.shape)
- combined_output = combined_output[: input_shape[0], :, :]
-
- self.record_all_to_all_stats()
-
- return combined_output, l_aux
-
- def prepare_for_inference_(self):
- self.in_generation = True
-
- def all_to_all_wrapper(self, input: Tensor):
- dummy_a2a = getattr(self.args, "dummy_a2a", False)
- if dummy_a2a:
- input = input.contiguous()
- output = input.detach().clone()
- return input
- # always record times, since it is not a lot of overhead
- # if we do not log it we simply clear it off in record_all_to_all_stats
- cuda_start = torch.cuda.Event(enable_timing=True)
- cuda_end = torch.cuda.Event(enable_timing=True)
- cpu_start = time.time() * 1000
- cuda_start.record()
- output = _AllToAll.apply(self.all2all_group, input)
- cuda_end.record()
- cpu_end = time.time() * 1000
- self.a2a_cpu_time_ms += cpu_end - cpu_start
- self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
- return output
-
- def record_all_to_all_stats(self):
- # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
- record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
- if record_a2a_perf_stats:
- torch.cuda.synchronize()
- self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
- a2a_cuda_time_ms = 0.0
- for ev_start, ev_end in self.a2a_cuda_event_intervals:
- a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
- self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
- # reset stats
- self.a2a_cpu_time_ms = 0.0
- self.a2a_cuda_event_intervals = []
diff --git a/code/xtuner/model/torchscale/component/xmoe/routing.py b/code/xtuner/model/torchscale/component/xmoe/routing.py
deleted file mode 100644
index 751a76a16195989a38a1f2289b8d19033f213ff8..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/xmoe/routing.py
+++ /dev/null
@@ -1,525 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf
-# Code is inspired by Top2GatingOnLogits from lingvo:
-# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
-
-# NOTE: This is a mirror of the code in
-# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
-
-import math
-from typing import Callable, Dict, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from torch import Tensor
-
-from .moe_layer import fused_cumsum_sub_one, has_tutel
-
-# use a fixed temperature to compute balance loss
-TEMPERATURE_FOR_L_UAX = 0.07
-
-# maximum capacity of 1 expert as a fraction of number of tokens in the batch
-# Note: setting this to 1.0 causes inference to significantly slow down
-EVAL_CAPACITY_TOKEN_FRACTION = 0.25
-
-# logging
-SAMPLE_FRACTION = 0.2
-
-
-def top1gating(
- logits: torch.Tensor,
- input_mask: Optional[torch.Tensor] = None,
- use_fp32=False,
- capacity_factor=1.0,
- eval_mode=False,
- moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
- use_xmoe=False,
- gate_obj=None,
-) -> Tuple[Tensor, Tensor, Tensor, Dict]:
- """Implements Top2Gating on logits."""
- metadata = {}
- if use_fp32:
- orig_dtype = logits.dtype
- logits = logits.float()
-
- gates = F.softmax(logits, dim=1)
- metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
-
- # gates has shape of SE
- num_tokens = gates.shape[0]
- num_experts = gates.shape[1]
- if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
- capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
- else:
- # capacity = capacity_factor * S/E
- capacity = int(capacity_factor * math.ceil(num_tokens / num_experts))
-
- # Create a mask for 1st's expert per token
- indices1_s = torch.argmax(gates, dim=1)
- mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
- if input_mask is not None and input_mask.any():
- nonpadding = ~input_mask
- mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
-
- # for logging (percent of tokens routed to each expert)
- expert1_hist = (
- 100
- * torch.histc(
- (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
- )
- / num_tokens
- )
- metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
- expert1_hist = (
- torch.sort(expert1_hist, dim=0, descending=True).values
- + torch.finfo(torch.float32).tiny
- )
-
- sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
- metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
- metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
-
- gates1_s = (gates * mask1).sum(dim=1)
-
- # Compute locations in capacity buffer
- locations1 = fused_cumsum_sub_one(mask1)
-
- # Compute l_aux
- me = torch.mean(gates, dim=0)
- ce = torch.mean(mask1.to(gates.dtype), dim=0)
-
- l_aux = torch.mean(me * ce)
- l_aux = l_aux * num_experts * num_experts
-
- if has_tutel:
- locations1_s = torch.sum(locations1 * mask1, dim=1)
- return (
- l_aux,
- metadata,
- capacity,
- num_experts,
- [
- indices1_s,
- ],
- [
- locations1_s,
- ],
- [
- gates1_s,
- ],
- )
-
- # Remove locations outside capacity from mask
- mask1 = mask1 * torch.lt(locations1, capacity)
- # Store the capacity location for each token
- locations1_s = torch.sum(locations1 * mask1, dim=1)
-
- # Calculate combine_weights and dispatch_mask
- gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
- # locations1_sc = num_tokens * capacity
- locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
- combine1_sec = torch.bmm(
- # einsum("se,sc->sec")
- gates1.unsqueeze(-1),
- locations1_sc.to(gates1.dtype).unsqueeze(1),
- )
- dispatch_mask = combine1_sec.bool()
- if use_fp32:
- return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata
- else:
- return l_aux, combine1_sec, dispatch_mask, metadata
-
-
-class Top1Gate(torch.nn.Module):
- """Gate module which implements Top2Gating as described in Gshard_.
- ::
-
- gate = Top2Gate(model_dim, num_experts)
- l_aux, combine_weights, dispatch_mask = gate(input)
-
- .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
-
- Args:
- model_dim (int):
- size of model embedding dimension
- num_experts (ints):
- number of experts in model
- """
-
- wg: torch.nn.Linear
-
- def __init__(
- self,
- model_dim: int,
- num_experts: int,
- use_fp32=False,
- input_noise_type=None,
- capacity_factor=1.0,
- moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
- use_xmoe=False,
- ) -> None:
- # TODO: merge this to top2gate.py
- #
- super().__init__()
-
- if not use_xmoe:
- self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
- else:
- self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
- wg = torch.empty(num_experts, 16)
- torch.nn.init.orthogonal_(wg, gain=0.32)
- self.register_parameter("wg", torch.nn.Parameter(wg))
-
- self.use_xmoe = use_xmoe
- self.use_fp32 = use_fp32
- self.input_noise_type = input_noise_type
- self.capacity_factor = capacity_factor
- self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
-
- def forward(self, input, mask=None): # type: ignore
- if self.use_xmoe:
- input = self.wg_reduction(input)
- with torch.no_grad():
- wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
- self.wg.mul_(1.5 / wg_norm)
- logits = self._cosine(input, self.wg)
- logits = self._make_finite(logits)
- else:
- logits = self.wg(input)
-
- return top1gating(
- logits,
- mask,
- use_fp32=self.use_fp32,
- capacity_factor=self.capacity_factor,
- eval_mode=not self.training,
- moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
- use_xmoe=self.use_xmoe,
- gate_obj=self,
- )
-
- def _make_finite(self, scores):
- ok = scores.isfinite()
- if not ok.all():
- # NaNs here can break the assignment algorithm
- scores[~ok] = scores[ok].min()
- return scores
-
- def _get_gating_temperature(self, eps=1e-4):
- if self.gating_t.data.item() < eps:
- return eps
- return self.gating_t
-
- def _cosine(self, mat1, mat2, eps=1e-4):
- assert mat1.dim() == 2
- assert mat2.dim() == 2
- # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
- mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
- return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
-
-
-gumbel_map: Dict[torch.device, Callable] = {}
-
-
-def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
- gumbel = gumbel_map.get(device)
- if gumbel is None:
- one = torch.tensor(1.0, device=device)
- zero = torch.tensor(0.0, device=device)
- gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
- gumbel_map[device] = gumbel
- return gumbel(shape)
-
-
-def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor:
- if unsqueeze_indices:
- indices = indices.unsqueeze(-1)
- assert indices.shape[-1] == 1, "last dimension of indices must be have size 1"
- output = torch.zeros(
- indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
- )
- output.scatter_(len(output.shape) - 1, indices, 1)
- return output
-
-
-def entropy(probs):
- logits = torch.distributions.utils.probs_to_logits(probs)
- p_log_p = probs * logits
- return -p_log_p.sum(-1)
-
-
-def top2gating(
- logits: torch.Tensor,
- input_mask: Optional[torch.Tensor] = None,
- use_fp32=False,
- second_expert_policy="sampling",
- normalize_gate_prob_before_dropping=False,
- eval_mode=False,
- moe_eval_capacity_token_fraction=0.25,
- batch_prioritized_routing=False,
-) -> Tuple[Tensor, Tensor, Tensor]:
- """Implements Top2Gating on logits."""
- metadata = {}
- if use_fp32:
- orig_dtype = logits.dtype
- logits = logits.float()
- gates = F.softmax(logits, dim=1)
- metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
- # gates has shape of SE
- num_tokens = gates.shape[0]
- num_experts = gates.shape[1]
- if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
- capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
- else:
- # capacity = 2S/E
- capacity = 2 * math.ceil(num_tokens / num_experts)
-
- # Create a mask for 1st's expert per token
- indices1_s = torch.argmax(gates, dim=1, keepdim=True)
- mask1 = one_hot(indices1_s, num_experts)
- if second_expert_policy == "sampling":
- # Create a mask for 2nd's expert per token using Gumbel-max trick
- # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
- logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
- else:
- logits_w_noise = logits
- # Replace top-expert with min value
- logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
- indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True)
- mask2 = one_hot(indices2_s, num_experts)
- gates1_s = (gates * mask1).sum(dim=1)
- gates2_s = (gates * mask2).sum(dim=1)
-
- if normalize_gate_prob_before_dropping:
- # Normalize gate probabilities
- denom_s = gates1_s + gates2_s
- # Avoid divide-by-zero
- denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
- gates1_s = gates1_s / denom_s
- gates2_s = gates2_s / denom_s
-
- if second_expert_policy == "random":
- sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
- mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
-
- # Compute locations in capacity buffer
- if input_mask is not None and input_mask.any():
- nonpadding = ~input_mask
- mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
- mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
-
- if batch_prioritized_routing:
- # if batch_prioritized_routing:
- importance_scores = -1 * gates.max(dim=1)[0]
- sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
- sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
- importance_sorted_locations1 = sorted_cumsum1[
- importance_scores.argsort(dim=0).argsort(dim=0)
- ]
-
- sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
- sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
- importance_sorted_locations2 = sorted_cumsum2[
- importance_scores.argsort(dim=0).argsort(dim=0)
- ]
-
- importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
-
- locations1, locations2 = (
- importance_sorted_locations1,
- importance_sorted_locations2,
- )
- else:
- locations1 = fused_cumsum_sub_one(mask1)
- locations2 = fused_cumsum_sub_one(mask2)
- # Update 2nd's location by accounting for locations of 1st
- locations2 += torch.sum(mask1, dim=0, keepdim=True)
-
- # Compute l_aux
- me = torch.mean(gates, dim=0)
- ce = torch.mean(mask1.to(gates.dtype), dim=0)
- l_aux = torch.mean(me * ce)
- l_aux = l_aux * num_experts * num_experts
-
- # for logging purposes
- metadata["overflow_expert1"] = (
- 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
- )
- metadata["overflow_expert2"] = (
- 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
- )
-
- # Remove locations outside capacity from mask
- mask1_, mask2_ = mask1, mask2
- mask1 = mask1 * torch.lt(locations1, capacity)
- mask2 = mask2 * torch.lt(locations2, capacity)
-
- # for logging (percent of tokens routed to each expert)
- expert1_hist = (
- 100
- * torch.histc(
- (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
- )
- / num_tokens
- )
- metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
- expert1_hist = (
- torch.sort(expert1_hist, dim=0, descending=True).values
- + torch.finfo(torch.float32).tiny
- )
-
- expert2_hist = (
- 100
- * torch.histc(
- (indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
- )
- / num_tokens
- )
- metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
- expert2_hist = (
- torch.sort(expert2_hist, dim=0, descending=True).values
- + torch.finfo(torch.float32).tiny
- )
-
- sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
- metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
- metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
-
- metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum()
- metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum()
-
- if not normalize_gate_prob_before_dropping:
- # Normalize gate probabilities
- gates1_s = (gates * mask1).sum(dim=1)
- gates2_s = (gates * mask2).sum(dim=1)
- denom_s = gates1_s + gates2_s
- # Avoid divide-by-zero
- denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
- gates1_s /= denom_s
- gates2_s /= denom_s
-
- if has_tutel:
- locations1_s = torch.sum(locations1 * mask1_, dim=1)
- locations2_s = torch.sum(locations2 * mask2_, dim=1)
- return (
- l_aux,
- metadata,
- capacity,
- num_experts,
- [indices1_s, indices2_s],
- [locations1_s, locations2_s],
- [gates1_s, gates2_s],
- )
-
- # Store the capacity location for each token
- locations1_s = torch.sum(locations1 * mask1, dim=1)
- locations2_s = torch.sum(locations2 * mask2, dim=1)
-
- # Calculate combine_weights and dispatch_mask
- gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
- gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se")
- locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
- locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
- combine1_sec = torch.bmm(
- # einsum("se,sc->sec")
- gates1.unsqueeze(-1),
- locations1_sc.to(gates1.dtype).unsqueeze(1),
- )
- combine2_sec = torch.bmm(
- # einsum("se,sc->sec")
- gates2.unsqueeze(-1),
- locations2_sc.to(gates2.dtype).unsqueeze(1),
- )
- combine_weights = combine1_sec + combine2_sec
- dispatch_mask = combine_weights.bool()
- if use_fp32:
- return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata
- else:
- return l_aux, combine_weights, dispatch_mask, metadata
-
-
-class Top2Gate(torch.nn.Module):
- """Gate module which implements Top2Gating as described in Gshard_.
- ::
-
- gate = Top2Gate(model_dim, num_experts)
- l_aux, combine_weights, dispatch_mask = gate(input)
-
- .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
-
- Args:
- model_dim (int):
- size of model embedding dimension
- num_experts (ints):
- number of experts in model
- """
-
- wg: torch.nn.Linear
-
- def __init__(
- self,
- model_dim: int,
- num_experts: int,
- use_fp32=False,
- second_expert_policy="sampling",
- normalize_gate_prob_before_dropping=False,
- moe_eval_capacity_token_fraction=0.25,
- batch_prioritized_routing=False,
- use_xmoe=False,
- ) -> None:
- super().__init__()
- if not use_xmoe:
- self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
- else:
- self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
- wg = torch.empty(num_experts, 16)
- torch.nn.init.orthogonal_(wg, gain=0.32)
- self.register_parameter("wg", torch.nn.Parameter(wg))
- self.use_fp32 = use_fp32
- self.second_expert_policy = second_expert_policy
- self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping
- self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
- self.batch_prioritized_routing = batch_prioritized_routing
- self.use_xmoe = use_xmoe
-
- def forward(self, input, mask=None): # type: ignore
- if self.use_xmoe:
- input = self.wg_reduction(input)
- with torch.no_grad():
- wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
- self.wg.mul_(1.5 / wg_norm)
- logits = self._cosine(input, self.wg)
- logits = self._make_finite(logits)
- else:
- logits = self.wg(input)
- return top2gating(
- logits,
- mask,
- use_fp32=self.use_fp32,
- second_expert_policy=self.second_expert_policy,
- normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping,
- eval_mode=not self.training,
- moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
- batch_prioritized_routing=self.batch_prioritized_routing,
- )
-
- def _cosine(self, mat1, mat2, eps=1e-4):
- assert mat1.dim() == 2
- assert mat2.dim() == 2
- # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
- mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
- return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
-
- def _make_finite(self, scores):
- ok = scores.isfinite()
- if not ok.all():
- # NaNs here can break the assignment algorithm
- scores[~ok] = scores[ok].min()
- return scores
diff --git a/code/xtuner/model/torchscale/component/xpos_relative_position.py b/code/xtuner/model/torchscale/component/xpos_relative_position.py
deleted file mode 100644
index a3ec129d9c075e1054d5b4c8a99baee4812feadf..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/component/xpos_relative_position.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-def fixed_pos_embedding(x):
- seq_len, dim = x.shape
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
- sinusoid_inp = (
- torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
- )
- return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
-
-def rotate_every_two(x):
- x1 = x[:, :, ::2]
- x2 = x[:, :, 1::2]
- x = torch.stack((-x2, x1), dim=-1)
- return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
-
-def duplicate_interleave(m):
- """
- A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
- """
- dim0 = m.shape[0]
- m = m.view(-1, 1) # flatten the matrix
- m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
- m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
- return m
-
-def apply_rotary_pos_emb(x, sin, cos, scale=1):
- sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
- # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
- return (x * cos) + (rotate_every_two(x) * sin)
-
-
-class XPOS(nn.Module):
- def __init__(
- self, head_dim, scale_base=512
- ):
- super().__init__()
- self.head_dim = head_dim
- self.scale_base = scale_base
- self.register_buffer(
- "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim)
- )
-
- def forward(self, x, offset=0, downscale=False):
- length = x.shape[1]
- min_pos = -(length + offset) // 2
- max_pos = length + offset + min_pos
- scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None]
- sin, cos = fixed_pos_embedding(scale)
-
- if scale.shape[0] > length:
- scale = scale[-length:]
- sin = sin[-length:]
- cos = cos[-length:]
-
- if downscale:
- scale = 1 / scale
-
- x = apply_rotary_pos_emb(x, sin, cos, scale)
- return x
diff --git a/code/xtuner/model/torchscale/model/BEiT3.py b/code/xtuner/model/torchscale/model/BEiT3.py
deleted file mode 100644
index 92737a21d857fdae516f72bf20f506eab8048646..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/BEiT3.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import torch
-import torch.nn as nn
-
-from torchscale.architecture.encoder import Encoder
-from torchscale.component.embedding import (
- PositionalEmbedding,
- TextEmbedding,
- VisionEmbedding,
-)
-from torchscale.component.multiway_network import MutliwayEmbedding
-
-
-class BEiT3(nn.Module):
- def __init__(self, args, **kwargs):
- super().__init__()
- self.args = args
- assert args.multiway
- assert args.vocab_size > 0
- assert not args.share_encoder_input_output_embed
- self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
- self.vision_embed = VisionEmbedding(
- args.img_size,
- args.patch_size,
- args.in_chans,
- args.encoder_embed_dim,
- contain_mask_token=True,
- prepend_cls_token=True,
- )
- # being consistent with Fairseq, which starts from 2 for position embedding
- embed_positions = MutliwayEmbedding(
- modules=[
- PositionalEmbedding(self.vision_embed.num_position_embeddings() + 2, args.encoder_embed_dim),
- PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
- ],
- dim=1,
- )
- self.encoder = Encoder(
- args,
- embed_tokens=None,
- embed_positions=embed_positions,
- output_projection=None,
- is_encoder_decoder=False,
- )
-
- def forward(
- self,
- textual_tokens=None,
- visual_tokens=None,
- text_padding_position=None,
- attn_mask=None,
- vision_masked_position=None,
- incremental_state=None,
- positions=None,
- ):
- assert textual_tokens is not None or visual_tokens is not None
-
- if textual_tokens is None:
- x = self.vision_embed(visual_tokens, vision_masked_position)
- encoder_padding_mask = None
- multiway_split_position = -1
- elif visual_tokens is None:
- x = self.text_embed(textual_tokens)
- encoder_padding_mask = text_padding_position
- multiway_split_position = 0
- else:
- x1 = self.vision_embed(visual_tokens, vision_masked_position)
- multiway_split_position = x1.size(1)
- x2 = self.text_embed(textual_tokens)
- x = torch.cat([x1, x2], dim=1)
-
- if text_padding_position is not None:
- encoder_padding_mask = torch.cat(
- [
- torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
- text_padding_position,
- ],
- dim=1,
- )
- else:
- encoder_padding_mask = None
-
- encoder_out = self.encoder(
- src_tokens=None,
- encoder_padding_mask=encoder_padding_mask,
- attn_mask=attn_mask,
- token_embeddings=x,
- multiway_split_position=multiway_split_position,
- incremental_state=incremental_state,
- positions=positions,
- )
- encoder_out["multiway_split_position"] = multiway_split_position
-
- return encoder_out
diff --git a/code/xtuner/model/torchscale/model/LongNet.py b/code/xtuner/model/torchscale/model/LongNet.py
deleted file mode 100644
index 33bf1763ba06033096d272f16ab406b06eb20303..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNet.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Copyright (c) 2023 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-import os
-import sys
-
-this_file_dir = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(this_file_dir, '../../'))
-
-from torchscale.model import LongNetConfig as longnet_arch
-from torchscale.architecture.config import EncoderConfig
-from torchscale.architecture.decoder import Decoder, DecoderLayer
-from torchscale.architecture.encoder import Encoder, EncoderLayer
-from torchscale.component.dilated_attention import DilatedAttention
-from fairscale.nn import checkpoint_wrapper, wrap
-
-
-class LongNetDecoderLayer(DecoderLayer):
-
- def build_self_attention(self, embed_dim, args):
- return DilatedAttention(
- args,
- embed_dim,
- args.decoder_attention_heads,
- dropout=args.attention_dropout,
- self_attention=True,
- encoder_decoder_attention=False,
- subln=args.subln,
- )
-
-class LongNetDecoder(Decoder):
-
- def build_decoder_layer(
- self, args, depth, is_moe_layer=False, is_encoder_decoder=False
- ):
- layer = LongNetDecoderLayer(
- args,
- depth,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- if args.checkpoint_activations:
- layer = checkpoint_wrapper(layer)
- if args.fsdp:
- layer = wrap(layer)
- return layer
-
-class LongNetEncoderLayer(EncoderLayer):
-
- def build_self_attention(self, embed_dim, args):
- return DilatedAttention(
- args,
- embed_dim,
- args.encoder_attention_heads,
- dropout=args.attention_dropout,
- self_attention=True,
- encoder_decoder_attention=False,
- subln=args.subln,
- )
-
-class LongNetEncoder(Encoder):
-
- def build_encoder_layer(
- self, args, depth, is_moe_layer=False, is_encoder_decoder=False
- ):
- layer = LongNetEncoderLayer(
- args,
- depth,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- if args.checkpoint_activations:
- layer = checkpoint_wrapper(layer)
- if args.fsdp:
- layer = wrap(layer)
- return layer
-
-def make_longnet(args):
- if args.arch in longnet_arch.__dict__.keys():
- longnet_args = longnet_arch.__dict__[args.arch]
- if hasattr(args, 'dropout'):
- longnet_args['dropout'] = args.dropout
- if hasattr(args, 'drop_path_rate'):
- longnet_args['drop_path_rate'] = args.drop_path_rate
- longnet_args = EncoderConfig(**longnet_args)
- model = LongNetEncoder(longnet_args)
- print('Number of trainable LongNet parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
- return model
-
-
-
-
-def make_longnet_from_name(config_name: str,
- dilated_ratio: str='[1, 2, 4, 8, 16]',
- segment_length: str='[1024, 2048, 4096, 8192, 16384]',
- drop_path_rate: int=0.1,
- dropout: float=0.1,
- enable_gradient_checkpoint = False,
- ):
- '''
- make LongNet model from config name
-
- Arguments:
- ----------
- config_name: str
- name of the config
- dilated_ratio: str
- dilated ratio
- segment_length: str
- segment length
- drop_path_rate: int
- drop path rate
- dropout: float
- dropout rate
- '''
- if config_name in longnet_arch.__dict__.keys():
- longnet_args = longnet_arch.__dict__[config_name]
-
- longnet_args = longnet_arch.__dict__[config_name]
-
- longnet_args['dropout'] = dropout
- longnet_args['drop_path_rate'] = drop_path_rate
-
- # set dilated ratio and segment length
- longnet_args['dilated_ratio'] = dilated_ratio
- longnet_args['segment_length'] = segment_length
- longnet_args['checkpoint_activations'] = enable_gradient_checkpoint
-
- print('dilated_ratio: ', dilated_ratio)
- print('segment_length: ', segment_length)
-
- longnet_args = EncoderConfig(**longnet_args)
- model = LongNetEncoder(longnet_args)
- print('Number of trainable LongNet parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
- return model
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/model/LongNetConfig.py b/code/xtuner/model/torchscale/model/LongNetConfig.py
deleted file mode 100644
index b5d33ba457fb1a7617b49a8860952691b2b613b2..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNetConfig.py
+++ /dev/null
@@ -1,397 +0,0 @@
-LongNet_8_layers_256_dim_mlp2 = {
- 'encoder_layers': 8,
- 'encoder_embed_dim': 256,
- 'encoder_ffn_embed_dim': 512,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4]',
- 'segment_length': '[512, 1024, 2048]',
- 'block_shift': True,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_12_layers_256_dim_mlp2 = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 256,
- 'encoder_ffn_embed_dim': 512,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4]',
- 'segment_length': '[512, 1024, 2048]',
- 'block_shift': True,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_8_layers_256_dim = {
- 'encoder_layers': 8,
- 'encoder_embed_dim': 256,
- 'encoder_ffn_embed_dim': 1024,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'block_shift': True,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_12_layers_256_dim = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 256,
- 'encoder_ffn_embed_dim': 1024,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'block_shift': True,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_3_layers_384_dim = {
- 'encoder_layers': 3,
- 'encoder_embed_dim': 384,
- 'encoder_ffn_embed_dim': 1536,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_6_layers_384_dim = {
- 'encoder_layers': 6,
- 'encoder_embed_dim': 384,
- 'encoder_ffn_embed_dim': 1536,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_12_layers_384_dim = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 384,
- 'encoder_ffn_embed_dim': 1536,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_12_layers_512_dim = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 512,
- 'encoder_ffn_embed_dim': 1024,
- 'encoder_attention_heads': 8,
- 'dilated_ratio': '[1, 2, 4]',
- 'segment_length': '[512, 1024, 2048]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_2_layers_512_dim = {
- 'encoder_layers': 2,
- 'encoder_embed_dim': 512,
- 'encoder_ffn_embed_dim': 1024,
- 'encoder_attention_heads': 8,
- 'dilated_ratio': '[1, 2, 4]',
- 'segment_length': '[512, 1024, 2048]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_2_layers_768_dim = {
- 'encoder_layers': 2,
- 'encoder_embed_dim': 768,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_3_layers_768_dim = {
- 'encoder_layers': 3,
- 'encoder_embed_dim': 768,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_6_layers_768_dim = {
- 'encoder_layers': 6,
- 'encoder_embed_dim': 768,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 4096, 8192, 16384, 65536]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_8_layers_768_dim = {
- 'encoder_layers': 8,
- 'encoder_embed_dim': 768,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_12_layers_768_dim = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 768,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_8_layers_1024_dim = {
- 'encoder_layers': 8,
- 'encoder_embed_dim': 1024,
- 'encoder_ffn_embed_dim': 4096,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- #'segment_length': '[512, 1024, 2048, 4096, 8192]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_2_layers_1024_dim = {
- 'encoder_layers': 2,
- 'encoder_embed_dim': 1024,
- 'encoder_ffn_embed_dim': 4096,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- #'segment_length': '[512, 1024, 2048, 4096, 8192]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-
-LongNet_24_layers_1024_dim = {
- 'encoder_layers': 24,
- 'encoder_embed_dim': 1024,
- 'encoder_ffn_embed_dim': 4096,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- #'segment_length': '[512, 1024, 2048, 4096, 8192]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_3_layers_1536_dim = {
- 'encoder_layers': 3,
- 'encoder_embed_dim': 1536,
- 'encoder_ffn_embed_dim': 6144,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_6_layers_1536_dim = {
- 'encoder_layers': 6,
- 'encoder_embed_dim': 1536,
- 'encoder_ffn_embed_dim': 6144,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_8_layers_1536_dim = {
- 'encoder_layers': 8,
- 'encoder_embed_dim': 1536,
- 'encoder_ffn_embed_dim': 6144,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- #'segment_length': '[512, 1024, 2048, 4096, 8192]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_12_layers_1536_dim = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 1536,
- 'encoder_ffn_embed_dim': 6144,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- #'segment_length': '[512, 1024, 2048, 4096, 8192]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_Vanilla_12_layers_256_dim = {
- 'encoder_layers': 12,
- 'encoder_embed_dim': 256,
- 'encoder_ffn_embed_dim': 512,
- 'encoder_attention_heads': 8,
- 'dilated_ratio': '[1]',
- 'segment_length': '[10000000]',
- 'block_shift': False,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_Vanilla_6_layers_768_dim = {
- 'encoder_layers': 6,
- 'encoder_embed_dim': 768,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1]',
- 'segment_length': '[10000000]',
- 'block_shift': False,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_Vanilla_6_layers_1536_dim = {
- 'encoder_layers': 6,
- 'encoder_embed_dim': 1536,
- 'encoder_ffn_embed_dim': 6144,
- 'encoder_attention_heads': 16,
- 'dilated_ratio': '[1]',
- 'segment_length': '[10000000]',
- 'block_shift': False,
- 'flash_attention': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-LongNet_test = {
- 'encoder_layers': 1,
- 'encoder_embed_dim': 192,
- 'encoder_ffn_embed_dim': 192,
- 'encoder_attention_heads': 8,
- 'dilated_ratio': '[1, 2, 4]',
- 'segment_length': '[512, 1024, 2048]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0
-}
-
-
-LongNet_3_layers_512_dim = {
- 'encoder_layers': 3,
- 'encoder_embed_dim': 512,
- 'encoder_ffn_embed_dim': 3072,
- 'encoder_attention_heads': 16,
- 'drop_path_rate': 0.1,
- 'dilated_ratio': '[1, 2, 4, 8, 16]',
- 'segment_length': '[1024, 2048, 4096, 8192, 16384]',
- 'flash_attention': True,
- 'block_shift': True,
- 'use_xmoe': False,
- 'moe_top1_expert': False,
- 'moe_freq': 0,
- 'moe_expert_count': 0,
- }
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/model/LongNetVit.py b/code/xtuner/model/torchscale/model/LongNetVit.py
deleted file mode 100644
index 8301998522be16d8fd534493642272bd9c94dc0a..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNetVit.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
-# DeiT: https://github.com/facebookresearch/deit
-# MAE: https://github.com/facebookresearch/mae
-# --------------------------------------------------------
-#
-# Portions Copyright Prov-GigaPath
-# Original File: https://github.com/facebookresearch/mae
-
-from functools import partial
-
-import os
-import sys
-import torch
-import torch.nn as nn
-import numpy as np
-
-import timm
-from timm.models.registry import register_model
-import huggingface_hub
-
-from xtuner.model.torchscale.model.pos_embed import get_2d_sincos_pos_embed
-from xtuner.model.torchscale.model.LongNet import make_longnet_from_name
-from xtuner.registry import BUILDER
-
-class PatchEmbed(nn.Module):
- """Slide Patch Embedding"""
-
- def __init__(
- self,
- in_chans=1536,
- embed_dim=768,
- norm_layer=None,
- bias=True,
- ):
- super().__init__()
-
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
- def forward(self, x):
- # print(x.dtype, self.proj.weight.dtype)
- B, L, D = x.shape
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-
-class LongNetViT(nn.Module):
- """
- Backbone of Vision Transformer for downstream tasks
-
- Arguments:
- ----------
- in_chans: int
- The number of input channels, should be the tile encoding dimension 1536.
- embed_dim: int
- The embedding dimension of the LongNet model.
- depth: int
- The number of LongNet layers in the LongNet model.
- slide_ngrids: int
- The number of grids in the slide.
- tile_size: int
- The tile size. Default is 256px.
- max_wsi_size: int
- The maximum size of the WSI.
- norm_layer: nn.LayerNorm
- The normalization layer used in the model.
- global_pool: bool
- Whether to use global pooling or not.
- dropout: float
- The dropout rate used in the model.
- drop_path_rate: float
- The drop path rate used in the model.
- """
-
- def __init__(self,
- in_chans=512,
- embed_dim=256,
- depth=12,
- slide_ngrids=1000,
- tile_size=224,
- max_wsi_size = 420096,
- norm_layer=nn.LayerNorm,
- global_pool=False,
- dropout=0.1,
- drop_path_rate=0.1,
- token_norm = False,
- **kwargs):
- super().__init__()
-
- # --------------------------------------------------------------------------
- # MAE encoder specifics
- self.patch_embed = PatchEmbed(in_chans, embed_dim)
-
- self.embed_dim = embed_dim
-
- self.tile_size = tile_size
- self.slide_ngrids = slide_ngrids
- num_patches = slide_ngrids**2
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- self.register_buffer('pos_embed', torch.zeros(1, num_patches + 1, embed_dim), persistent=False) # fixed sin-cos embedding
-
- self.encoder_name = "LongNet_{}_layers_{}_dim".format(depth, embed_dim)
- if kwargs.get("mlp_ratio", 4.0) != 4.0:
- self.encoder_name += "_mlp{}".format(kwargs.get("mlp_ratio"))
-
- # get optimal segment length
- segment_length = self.get_optimal_segment_length(max_wsi_size, tile_size)
- self.encoder = make_longnet_from_name(self.encoder_name, drop_path_rate=drop_path_rate, dropout=dropout, segment_length=segment_length)
- self.norm = norm_layer(embed_dim)
- # --------------------------------------------------------------------------
-
- self.global_pool = global_pool
- self.token_norm = token_norm
- print("Global Pooling:", self.global_pool)
-
- self.initialize_vit_weights()
-
-
- def initialize_vit_weights(self):
- # initialization
- # initialize (and freeze) pos_embed by sin-cos embedding
- pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.slide_ngrids, cls_token=True)
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-
- # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
- w = self.patch_embed.proj.weight.data
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-
- # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
- torch.nn.init.normal_(self.cls_token, std=0.02)
-
- # initialize nn.Linear and nn.LayerNorm
- self.apply(self._init_weights)
-
- def get_optimal_segment_length(self, max_wsi_size: int=262144, tile_size: int=256) -> str:
- '''
- Get the optimal segment length based on the maximum image size and tile size.
-
- Arguments:
- ----------
- max_wsi_size: int
- The maximum size of the WSI.
- tile_size: int
- The tile size.
- '''
- max_seq_len = (max_wsi_size // tile_size) ** 2
- # calculate the segment length
- segment_length = np.linspace(np.log2(1024), int(np.log2(max_seq_len)), 5)
- segment_length = np.power(2, segment_length).astype(int)
- # convert to str format
- segment_length = str(list(segment_length))
- return segment_length
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- # we use xavier_uniform following official JAX ViT:
- torch.nn.init.xavier_uniform_(m.weight)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- def coords_to_pos(self, coords, tile_size: int = 256):
- """
- This function is used to convert the coordinates to the positional indices
-
- Arguments:
- ----------
- coords: torch.Tensor
- The coordinates of the patches, of shape [N, L, 2]
- output: torch.Tensor
- The positional indices of the patches, of shape [N, L]
- """
- coords_ = torch.floor(coords / tile_size)
- pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
- return pos.long() + 1 # add 1 for the cls token
-
- def forward(self, x, coords, all_layer_embed=False):
- """
- The forward pass of the model
-
- Arguments:
- ----------
- x: torch.Tensor
- The input tile embeddings, of shape [N, L, D]
- coords: torch.Tensor
- The coordinates of the patches, of shape [N, L, 2]
- all_layer_embed: bool
- Whether to return embeddings from all layers or not
- """
- # embed patches
- x = self.patch_embed(x)
-
- # get pos indices
- pos = self.coords_to_pos(coords, self.tile_size) # [N, L]
-
- x = x + self.pos_embed[:, pos, :].squeeze(0)
-
- # append cls token
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
- x = torch.cat((cls_tokens, x), dim=1)
-
- # apply Transformer blocks
- if all_layer_embed:
- x_list = self.encoder(src_tokens=None, token_embeddings=x, return_all_hiddens=all_layer_embed)["encoder_states"]
- else:
- x_list = [self.encoder(src_tokens=None, token_embeddings=x)["encoder_out"]]
-
- outcomes = []
- for x in x_list:
- if self.global_pool:
- x = x[:, 1:, :].mean(dim=1) # global average pooling
- outcome = self.norm(x)
- elif self.token_norm:
- outcome = self.norm(x)
- else:
- x = self.norm(x)
- outcome = x[:, 0]
- outcomes.append(outcome)
-
- return outcomes
-
-
-
-
-def _prune_mismatched_keys(state_dict: dict, model: torch.nn.Module):
- model_sd = model.state_dict()
- to_drop = []
- for k, v in list(state_dict.items()):
- if k in model_sd and model_sd[k].shape != v.shape:
- to_drop.append(k)
- for k in to_drop:
- state_dict.pop(k)
- if to_drop:
- print("\033[93mSkipping incompatible keys:\033[00m")
- for k in to_drop:
- print(" -", k)
- return state_dict
-
-def _reinit_first_layer(model: torch.nn.Module, method: str = "xavier_uniform", seed: int | None = None):
- """Reinitialize patch_embed.proj after skipping it from the checkpoint."""
- m: nn.Linear = model.patch_embed.proj
- if seed is not None:
- torch.manual_seed(seed)
- if method == "xavier_uniform":
- nn.init.xavier_uniform_(m.weight)
- elif method == "xavier_normal":
- nn.init.xavier_normal_(m.weight)
- elif method == "kaiming_uniform":
- nn.init.kaiming_uniform_(m.weight, nonlinearity="linear")
- elif method == "kaiming_normal":
- nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
- else:
- nn.init.normal_(m.weight, std=0.02)
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- print("\033[96mReinitialized first layer: patch_embed.proj ({})\033[00m".format(method))
-
-def _sync_first_layer_dtype_device(model: torch.nn.Module):
- """
- Ensure patch_embed.proj has the SAME dtype & device as the rest of the model.
- Uses the first non-proj parameter as reference.
- """
- proj: nn.Linear = model.patch_embed.proj
- # Find a reference parameter that is not the proj weight/bias
- ref_param = None
- for p in model.parameters():
- if p is not proj.weight and (proj.bias is None or p is not proj.bias):
- ref_param = p
- break
- if ref_param is None:
- return # degenerate case
-
- # Move ONLY the first layer to the reference dtype/device
- proj.to(device=ref_param.device, dtype=ref_param.dtype)
- print(f"\033[96mSynchronized first layer dtype/device → {ref_param.dtype} @ {ref_param.device}\033[00m")
-
-
-def create_model_original(pretrained: str, model_arch: str, in_chans: int, local_dir: str = os.path.join(os.path.expanduser("~"), ".cache/"), **kwargs):
- model = timm.create_model(model_arch, pretrained=False, in_chans=in_chans, **kwargs)
-
- if pretrained.startswith("hf_hub:"):
- hub_name = pretrained.split(":")[1]
- huggingface_hub.hf_hub_download(hub_name, filename="slide_encoder.pth", local_dir=local_dir, force_download=True)
- local_path = os.path.join(local_dir, "slide_encoder.pth")
- else:
- local_path = pretrained
-
- if os.path.exists(local_path):
- state_dict = torch.load(local_path, map_location="cpu")["model"]
-
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
- if len(missing_keys) > 0:
- for k in missing_keys:
- print("Missing ", k)
-
- if len(unexpected_keys) > 0:
- for k in unexpected_keys:
- print("Unexpected ", k)
-
- print("\033[92m Successfully Loaded Pretrained GigaPath model from {} \033[00m".format(pretrained))
- else:
- print("\033[93m Pretrained weights not found at {}. Randomly initialized the model! \033[00m".format(local_path))
-
- return model
-
-def create_model(
- pretrained: str,
- model_arch: str,
- in_chans: int,
- local_dir: str = os.path.join(os.path.expanduser("~"), ".cache/"),
- reinit_method: str = "xavier_uniform",
- reinit_seed: int | None = None,
- map_location = 'cpu',
- **kwargs
-):
- model = timm.create_model(model_arch, pretrained=False, in_chans=in_chans, **kwargs)
-
- # Resolve checkpoint path
- if pretrained.startswith("hf_hub:"):
- hub_name = pretrained.split(":")[1]
- huggingface_hub.hf_hub_download(hub_name, filename="slide_encoder.pth", local_dir=local_dir, force_download=True)
- local_path = os.path.join(local_dir, "slide_encoder.pth")
- else:
- local_path = pretrained
-
- if os.path.exists(local_path):
- # Safer load when available
- try:
- obj = torch.load(local_path, map_location=map_location, weights_only=True)
- except TypeError:
- obj = torch.load(local_path, map_location=map_location)
- state_dict = obj.get("model", obj)
-
- # 1) Explicitly drop first layer (void copy)
- for k in ["patch_embed.proj.weight", "patch_embed.proj.bias"]:
- if k in state_dict:
- state_dict.pop(k)
-
- # 2) Drop any other mismatched tensors for robustness
- state_dict = _prune_mismatched_keys(state_dict, model)
-
- # 3) Load remaining weights
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-
- if missing_keys:
- print("\033[93mMissing keys (random/init):\033[00m")
- for k in missing_keys:
- print(" -", k)
- if unexpected_keys:
- print("\033[93mUnexpected keys (ignored):\033[00m")
- for k in unexpected_keys:
- print(" -", k)
-
- # 4) Re-initialize the first layer explicitly
- _reinit_first_layer(model, method=reinit_method, seed=reinit_seed)
- _sync_first_layer_dtype_device(model)
-
- print("\033[92mLoaded pretrained weights with first layer skipped & reinitialized from {}\033[00m".format(pretrained))
- else:
- print("\033[93mPretrained weights not found at {}. Model is randomly initialized.\033[00m".format(local_path))
- # Ensure consistent init for the first layer anyway
- _reinit_first_layer(model, method=reinit_method, seed=reinit_seed)
-
- return model
-
-@register_model
-def gigapath_slide_enc12l768d(**kwargs):
- model = LongNetViT(embed_dim=768, depth=12, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
- return model
-
-@BUILDER.register_module()
-@register_model
-def gigapath_slide_enc3l1536d(**kwargs):
- model = LongNetViT(embed_dim=1536, depth=3, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
- return model
-
-@register_model
-def gigapath_slide_enc24l1024d(**kwargs):
- model = LongNetViT(embed_dim=1024, depth=4, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
- return model
-
-@register_model
-def gigapath_slide_enc12l1536d(**kwargs):
- model = LongNetViT(embed_dim=1536, depth=12, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
- return model
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/__init__.py b/code/xtuner/model/torchscale/model/LongNetWithMerging/__init__.py
deleted file mode 100644
index 00bba8dbcd18308a93af53574f33d56fda5cad13..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNetWithMerging/__init__.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from .swin_longnet_encoder import LongNetWithPatchMerging
-from torchscale.model import LongNetConfig as longnet_arch
-from torchscale.architecture.config import EncoderConfig
-
-def make_longnet(args):
- if args.arch in longnet_arch.__dict__.keys():
- longnet_args = longnet_arch.__dict__[args.arch]
- if hasattr(args, 'dropout'):
- longnet_args['dropout'] = args.dropout
- if hasattr(args, 'drop_path_rate'):
- longnet_args['drop_path_rate'] = args.drop_path_rate
- longnet_args = EncoderConfig(**longnet_args)
- model = LongNetWithPatchMerging(longnet_args)
- print('Number of trainable LongNet parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
- return model
-
-
-
-def make_swin_longnet_from_name(config_name: str,
- dilated_ratio: str='[1, 2, 4, 8, 16]',
- segment_length: str='[1024, 2048, 4096, 8192, 16384]',
- drop_path_rate: int=0.1,
- dropout: float=0.1,
- enable_gradient_checkpoint = True,
- keep_dim_after_merge = True,
- merge_size = 2,
-
- use_rel_pos_2d = False
- ):
- '''
- make LongNet model from config name
-
- Arguments:
- ----------
- config_name: str
- name of the config
- dilated_ratio: str
- dilated ratio
- segment_length: str
- segment length
- drop_path_rate: int
- drop path rate
- dropout: float
- dropout rate
- '''
- if config_name in longnet_arch.__dict__.keys():
- longnet_args = longnet_arch.__dict__[config_name]
-
- longnet_args = longnet_arch.__dict__[config_name]
-
- longnet_args['dropout'] = dropout
- longnet_args['drop_path_rate'] = drop_path_rate
-
- # set dilated ratio and segment length
- longnet_args['dilated_ratio'] = dilated_ratio
- longnet_args['segment_length'] = segment_length
- longnet_args['checkpoint_activations'] = enable_gradient_checkpoint
-
- print('dilated_ratio: ', dilated_ratio)
- print('segment_length: ', segment_length)
-
- longnet_args = EncoderConfig(**longnet_args)
-
- model = LongNetWithPatchMerging(longnet_args,
- keep_dim_after_merge=keep_dim_after_merge,
- merge_size= merge_size,
- use_rel_pos_2d=use_rel_pos_2d
- )
- print('Number of trainable LongNet parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
- return model
-
-__all__ = [
- 'make_swin_longnet_from_name'
-]
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 01c8afd23262bd7d521569c59aaae8613a3ef43e..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/relative_position_bias_2d.cpython-311.pyc b/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/relative_position_bias_2d.cpython-311.pyc
deleted file mode 100644
index 04ce3578187c86a31255173e93c94f210c102e86..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/relative_position_bias_2d.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/sparse_patch_merging.cpython-311.pyc b/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/sparse_patch_merging.cpython-311.pyc
deleted file mode 100644
index e050c5d653555d7a210b089fa1e6f04cb1a755a8..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/sparse_patch_merging.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/swin_longnet_encoder.cpython-311.pyc b/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/swin_longnet_encoder.cpython-311.pyc
deleted file mode 100644
index 820fb6f8cee6a341515c46ae809e503db62445a3..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/LongNetWithMerging/__pycache__/swin_longnet_encoder.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/relative_position_bias_2d.py b/code/xtuner/model/torchscale/model/LongNetWithMerging/relative_position_bias_2d.py
deleted file mode 100644
index c2050875467d0ec26e0568fe6cd3df44f9b4e5ee..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNetWithMerging/relative_position_bias_2d.py
+++ /dev/null
@@ -1,73 +0,0 @@
-
-from typing import Optional
-import torch
-import torch.nn as nn
-
-class RelativePositionBias2D(nn.Module):
- def __init__(
- self,
- n_heads: int,
- num_buckets: int = 32,
- max_distance: int = 128,
- bidirectional: bool = True,
- ):
- super().__init__()
- self.n_heads = n_heads
- self.num_buckets = num_buckets
- self.max_distance = max_distance
- self.bidirectional = bidirectional
-
- self.rpb_x = nn.Embedding(num_buckets, n_heads)
- self.rpb_y = nn.Embedding(num_buckets, n_heads)
- nn.init.zeros_(self.rpb_x.weight)
- nn.init.zeros_(self.rpb_y.weight)
-
- @torch.no_grad()
- def _relative_position_bucket(self, relative_position: torch.Tensor) -> torch.Tensor:
- num_buckets = self.num_buckets
- max_distance = self.max_distance
- n = -relative_position
- if self.bidirectional:
- num_buckets //= 2
- sign = (n < 0).to(torch.long)
- n = n.abs()
- else:
- sign = None
- n = torch.clamp(n, min=0)
-
- max_exact = num_buckets // 2
- is_small = n < max_exact
- val_if_large = max_exact + (
- (torch.log(n.float() / max_exact + 1e-6) / torch.log(torch.tensor(max_distance / max_exact)))
- * (num_buckets - max_exact)
- ).to(torch.long)
- val_if_large = torch.clamp(val_if_large, max=num_buckets - 1)
- buckets = torch.where(is_small, n, val_if_large)
- if self.bidirectional:
- buckets = buckets + (sign * num_buckets)
- return buckets
-
- def compute_from_coords(self, coords_q: torch.Tensor, coords_k: torch.Tensor) -> torch.Tensor:
- # Accept [L,2] or [B,L,2]; if 3D, use the first (shared) coords
- if coords_q.dim() == 3:
- coords_q = coords_q[0]
- if coords_k.dim() == 3:
- coords_k = coords_k[0]
-
- assert coords_q.dim() == 2 and coords_q.size(-1) == 2
- assert coords_k.dim() == 2 and coords_k.size(-1) == 2
-
- dy = coords_k[:, 0].unsqueeze(0) - coords_q[:, 0].unsqueeze(1)
- dx = coords_k[:, 1].unsqueeze(0) - coords_q[:, 1].unsqueeze(1)
-
- by = self._relative_position_bucket(dy)
- bx = self._relative_position_bucket(dx)
-
- vy = self.rpb_y(by)
- vx = self.rpb_x(bx)
- v = (vx + vy).permute(2, 0, 1).contiguous()
- return v
-
- def forward(self, batch_size: int, coords_q: torch.Tensor, coords_k: torch.Tensor) -> torch.Tensor:
- v = self.compute_from_coords(coords_q, coords_k)
- return v.unsqueeze(0).repeat(batch_size, 1, 1, 1).view(-1, v.size(1), v.size(2))
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/sparse_patch_merging.py b/code/xtuner/model/torchscale/model/LongNetWithMerging/sparse_patch_merging.py
deleted file mode 100644
index 28e7a20ca10fb6654cf4d1586eef60d484c517a0..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNetWithMerging/sparse_patch_merging.py
+++ /dev/null
@@ -1,524 +0,0 @@
-
-# from typing import Optional, Dict, Tuple
-# import torch
-# import torch.nn as nn
-# try:
-# from apex.normalization import FusedLayerNorm as LayerNorm
-# except ModuleNotFoundError:
-# from torch.nn import LayerNorm
-
-
-# class SparsePatchMerging(nn.Module):
-# def __init__(
-# self,
-# args,
-# embed_dim: int,
-# layernorm_eps: float,
-# keep_dim: bool = True,
-# merge_size: int = 2,
-# out_dim: Optional[int] = None,
-
-# subln = False,
-# ):
-# super().__init__()
-# self.args = args
-
-# assert merge_size == 2, "Only 2x2 merging is supported (Swin-style)."
-
-# self.embed_dim = embed_dim
-# self.merge_size = merge_size
-# self.keep_dim = keep_dim
-# # self.use_pre_norm = use_pre_norm
-
-# in_dim = embed_dim * (merge_size ** 2)
-# if out_dim is None:
-# out_dim = embed_dim if keep_dim else (2 * embed_dim)
-# self.out_dim = out_dim
-
-# self.layer_norm = LayerNorm(in_dim, eps=layernorm_eps) if subln else None
-# self.reduction = nn.Linear(in_dim, out_dim)
-
-
-
-# def reset_parameters(self):
-# # self.fc1.reset_parameters()
-# nn.init.xavier_uniform_(self.reduction.weight)
-# if self.layer_norm is not None:
-# self.layer_norm.reset_parameters()
-
-# def forward(
-# self,
-# x: torch.Tensor, # [B, L, C]
-# coords_rc: torch.Tensor, # [L, 2] (row, col) integer indices OR raw coords (we will rank)
-# padmask: Optional[torch.Tensor] # [B, L] bool
-# ):
-# B, L, C = x.shape
-# device = x.device
-
-# if coords_rc.dtype != torch.long and coords_rc.dtype != torch.int64:
-# coords_rc = coords_rc.to(torch.long)
-# rows = coords_rc[:, 0]
-# cols = coords_rc[:, 1]
-# rows_sorted = rows.unique(sorted=True)
-# cols_sorted = cols.unique(sorted=True)
-# row_idx = torch.searchsorted(rows_sorted, rows)
-# col_idx = torch.searchsorted(cols_sorted, cols)
-# rc_idx = torch.stack([row_idx, col_idx], dim=-1) # [L, 2]
-
-# r_group = rc_idx[:, 0] // self.merge_size
-# c_group = rc_idx[:, 1] // self.merge_size
-# group_pairs = torch.stack([r_group, c_group], dim=-1) # [L, 2]
-# unique_groups, inverse = torch.unique(group_pairs, dim=0, return_inverse=True)
-
-# G = unique_groups.size(0)
-# M = self.merge_size * self.merge_size # 4
-
-# # print("SparsePatchMerging: Merging {} tokens into {} groups.".format(L, G))
-# # print('inverse:', inverse)
-
-# r_off = rc_idx[:, 0] % self.merge_size
-# c_off = rc_idx[:, 1] % self.merge_size
-# offset_id = r_off * self.merge_size + c_off # [L] in {0,1,2,3}
-# slot_idx = inverse * M + offset_id # [L] in [0, G*M)
-
-# # print('slot_idx shape:', slot_idx.shape)
-# merged_slots = x.new_zeros(B, G * M, C)
-# for b in range(B):
-# merged_slots[b].index_add_(0, slot_idx, x[b])
-
-# merged_slots = merged_slots.view(B, G, M, C)
-# merged_concat = merged_slots.flatten(2) # [B, G, 4*C]
-# if self.layer_norm is not None:
-# out = self.reduction(self.layer_norm(merged_concat)) # [B, G, C_out]
-# else:
-# out = self.reduction(merged_concat)
-
-# # recalculate the padding mask
-# if padmask is not None:
-# valid = (~padmask).to(x.dtype) # [B, L] in {0,1}
-# group_valid = valid.new_zeros(B, G * M)
-# for b in range(B):
-# group_valid[b].index_add_(0, slot_idx, valid[b])
-# group_valid_any = group_valid.view(B, G, M).sum(dim=-1) > 0 # [B, G] bool
-# padmask_out = ~group_valid_any
-# else:
-# padmask_out = None
-
-# aux = {
-# "unique_groups": unique_groups, # [G, 2] rows//2, cols//2
-# "slot_idx": slot_idx, # [L]
-# "inverse": inverse, # [L] group index per token
-# }
-
-# coords_out = unique_groups
-# return out, coords_out, padmask_out, aux
-
-
-# from typing import Optional
-# import torch
-# import torch.nn as nn
-# try:
-# from apex.normalization import FusedLayerNorm as LayerNorm
-# except ModuleNotFoundError:
-# from torch.nn import LayerNorm
-
-
-# class SparsePatchMerging(nn.Module):
-# def __init__(
-# self,
-# args,
-# embed_dim: int,
-# layernorm_eps: float,
-# keep_dim: bool = True,
-# merge_size: int = 2,
-# out_dim: Optional[int] = None,
-# subln: bool = False,
-# # NEW: cache-control knobs
-# clear_after_forward: bool = False, # clear CUDA cache after each forward
-# clear_on_oom: bool = True, # clear & retry once on OOM
-# ):
-# super().__init__()
-# self.args = args
-# self.clear_after_forward = clear_after_forward
-# self.clear_on_oom = clear_on_oom
-# self._retrying = False # guard to avoid infinite recursion on OOM
-
-# assert merge_size == 2, "Only 2x2 merging is supported (Swin-style)."
-
-# self.embed_dim = embed_dim
-# self.merge_size = merge_size
-# self.keep_dim = keep_dim
-
-# in_dim = embed_dim * (merge_size ** 2)
-# if out_dim is None:
-# out_dim = embed_dim if keep_dim else (2 * embed_dim)
-# self.out_dim = out_dim
-
-# self.layer_norm = LayerNorm(in_dim, eps=layernorm_eps) if subln else None
-# self.reduction = nn.Linear(in_dim, out_dim)
-
-# def _clear_cuda_cache(self):
-# if torch.cuda.is_available():
-# # Release cached blocks back to the driver, and collect any stray IPC handles.
-# torch.cuda.empty_cache()
-# try:
-# torch.cuda.ipc_collect()
-# except Exception:
-# # ipc_collect may not be available on some builds/devices; ignore.
-# pass
-
-# def reset_parameters(self):
-# nn.init.xavier_uniform_(self.reduction.weight)
-# if self.layer_norm is not None:
-# self.layer_norm.reset_parameters()
-
-# def _forward_impl(
-# self,
-# x: torch.Tensor, # [B, L, C]
-# coords_rc: torch.Tensor, # [L, 2] or [B, L, 2]
-# padmask: Optional[torch.Tensor] # [B, L] bool
-# ):
-# B, L, C = x.shape
-# device = x.device
-
-# # --- NEW: normalize coords_rc to [L,2] ---
-# if coords_rc.dim() == 3:
-# if coords_rc.size(0) != B:
-# raise ValueError(f"coords batch dim mismatch: got {coords_rc.size(0)} but inputs have B={B}")
-# if B == 1:
-# coords_rc = coords_rc[0]
-# else:
-# if not torch.equal(coords_rc, coords_rc[0].unsqueeze(0).expand_as(coords_rc)):
-# raise NotImplementedError(
-# "Per-example coords (varying across batch) are not supported by the current "
-# "merging implementation. Use batch size 1 or share coords across the batch."
-# )
-# coords_rc = coords_rc[0]
-# elif coords_rc.dim() != 2:
-# raise ValueError("coords_rc must be [L,2] or [B,L,2].")
-# if coords_rc.size(-1) != 2:
-# raise ValueError("coords_rc last dimension must be 2.")
-# # --- end NEW ---
-
-# if coords_rc.dtype not in (torch.long, torch.int64):
-# coords_rc = coords_rc.to(torch.long)
-
-# # Pure index math; ensure no autograd tracking (saves tiny overhead).
-# with torch.no_grad():
-# rows = coords_rc[:, 0]
-# cols = coords_rc[:, 1]
-# rows_sorted = rows.unique(sorted=True)
-# cols_sorted = cols.unique(sorted=True)
-# row_idx = torch.searchsorted(rows_sorted, rows)
-# col_idx = torch.searchsorted(cols_sorted, cols)
-# rc_idx = torch.stack([row_idx, col_idx], dim=-1) # [L, 2]
-
-# r_group = rc_idx[:, 0] // self.merge_size
-# c_group = rc_idx[:, 1] // self.merge_size
-# group_pairs = torch.stack([r_group, c_group], dim=-1) # [L, 2]
-# unique_groups, inverse = torch.unique(group_pairs, dim=0, return_inverse=True)
-
-# G = unique_groups.size(0)
-# M = self.merge_size * self.merge_size # 4
-
-# r_off = rc_idx[:, 0] % self.merge_size
-# c_off = rc_idx[:, 1] % self.merge_size
-# offset_id = r_off * self.merge_size + c_off # [L] in {0,1,2,3}
-# slot_idx = (inverse * M + offset_id).to(device) # ensure device alignment
-
-# # Build merged slots
-# merged_slots = x.new_zeros(B, G * M, C)
-# for b in range(B):
-# merged_slots[b].index_add_(0, slot_idx, x[b])
-
-# merged_slots = merged_slots.view(B, G, M, C)
-# merged_concat = merged_slots.flatten(2) # [B, G, 4*C]
-# if self.layer_norm is not None:
-# out = self.reduction(self.layer_norm(merged_concat)) # [B, G, C_out]
-# else:
-# out = self.reduction(merged_concat)
-
-# # recalculate the padding mask
-# padmask_out = None
-# if padmask is not None:
-# valid = (~padmask).to(x.dtype) # [B, L] in {0,1}
-# group_valid = valid.new_zeros(B, G * M)
-# for b in range(B):
-# group_valid[b].index_add_(0, slot_idx, valid[b])
-# group_valid_any = group_valid.view(B, G, M).sum(dim=-1) > 0 # [B, G] bool
-# padmask_out = ~group_valid_any
-
-# # Prepare aux (kept lightweight; none of these require grad)
-# aux = {
-# "unique_groups": unique_groups.to(device), # [G, 2]
-# "slot_idx": slot_idx, # [L]
-# "inverse": inverse.to(device), # [L]
-# }
-
-# # Proactively drop large temporaries before returning
-# del merged_slots, merged_concat
-# if padmask is not None:
-# del group_valid
-# del group_valid_any
-# del valid
-
-# # Optional allocator cleanup after forward
-# if self.clear_after_forward and x.is_cuda:
-# self._clear_cuda_cache()
-
-# coords_out = unique_groups
-# return out, coords_out, padmask_out, aux
-
-# def forward(
-# self,
-# x: torch.Tensor,
-# coords_rc: torch.Tensor,
-# padmask: Optional[torch.Tensor]
-# ):
-# try:
-# return self._forward_impl(x, coords_rc, padmask)
-# except RuntimeError as e:
-# msg = str(e).lower()
-# is_cuda_oom = ("cuda out of memory" in msg) or ("cublas" in msg and "alloc" in msg)
-# if self.clear_on_oom and is_cuda_oom and x.is_cuda and not self._retrying:
-# # Clear cache and retry once
-# self._retrying = True
-# self._clear_cuda_cache()
-# torch.cuda.reset_peak_memory_stats()
-# try:
-# return self._forward_impl(x, coords_rc, padmask)
-# finally:
-# self._retrying = False
-# raise
-
-
-# sparse_patch_merging.py
-from typing import Optional, Tuple
-import torch
-import torch.nn as nn
-
-class SparsePatchMerging(nn.Module):
- """
- Stable 2x2 (stride-2) patch merging for Swin/LongNet with:
- - Pre/Post LayerNorm (eps small) for bf16 stability
- - Deterministic TL/TR/BL/BR concat ordering
- - FP32 reductions for scatter/index_add
- - Variance-preserving "sum" option
- - Robust handling of ragged tokens via coords_rc + padmask
- - Optional clear-on-OOM retry
-
- forward(x: [B,L,C], coords_rc: [B,L,2] or [L,2], padmask: Optional[bool[B,L]])
- -> (x_merged: [B,L_out,C'], coords_merged: [B,L_out,2], padmask_merged: [B,L_out])
- """
- def __init__(
- self,
- args,
- embed_dim: int,
- layernorm_eps: float,
- keep_dim: bool = True,
- merge_size: int = 2,
- out_dim: Optional[int] = None,
- pre_norm: bool = True,
- post_norm: bool = True,
- fp32_reduce: bool = True,
- mode: str = 'concat', # 'concat' or 'sum'
- clear_after_forward: bool = False,
- clear_on_oom: bool = True,
- **kwargs,
- ) -> None:
- super().__init__()
- assert merge_size == 2, "Only 2x2 merging supported."
- self.args = args
- self.embed_dim = embed_dim
- self.merge_size = merge_size
- self.keep_dim = keep_dim
- self.clear_after_forward = clear_after_forward
- self.clear_on_oom = clear_on_oom
- self._retrying = False
-
- if out_dim is None:
- out_dim = embed_dim if keep_dim else (2 * embed_dim)
- self.out_dim = out_dim
-
- self.pre_norm = pre_norm
- self.post_norm = post_norm
- self.fp32_reduce = fp32_reduce
- self.mode = mode
-
- eps = layernorm_eps if layernorm_eps is not None else 1e-6
- self.ln_in = nn.LayerNorm(embed_dim, eps=eps) if pre_norm else None
- in_linear = embed_dim * 4 if mode == 'concat' else embed_dim
- # change the reduction to MLP instead of single linear layer
- # self.reduction1 = nn.Sequential([
- # nn.Linear(in_linear, in_linear//2),
- # nn.GELU(),
- # nn.Linear(in_linear // 2, out_dim)
- # ])
- self.reduction1 = nn.Linear(in_linear, in_linear//2)
- self.act = nn.GELU()
- self.reduction2 = nn.Linear(in_linear //2, out_dim)
-
- self.ln_out = nn.LayerNorm(out_dim, eps=eps) if post_norm else None
-
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
-
- nn.init.xavier_uniform_(self.reduction1.weight)
- nn.init.xavier_uniform_(self.reduction2.weight)
-
- if self.reduction1.bias is not None:
- nn.init.zeros_(self.reduction1.bias)
- if self.reduction2.bias is not None:
- nn.init.zeros_(self.reduction2.bias)
-
- if self.ln_in is not None:
- self.ln_in.reset_parameters()
- if self.ln_out is not None:
- self.ln_out.reset_parameters()
-
-
- def _clear_cuda_cache(self):
- try:
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
- except Exception:
- pass
-
- @staticmethod
- def _ensure_batched_coords(coords_rc: torch.Tensor, B: int) -> torch.Tensor:
- if coords_rc.dim() == 2:
- coords_rc = coords_rc.unsqueeze(0).expand(B, -1, -1)
- return coords_rc
-
- def _forward_impl(
- self,
- x: torch.Tensor, # [B, L, C]
- coords_rc: torch.Tensor, # [B, L, 2] or [L, 2]
- padmask: Optional[torch.Tensor] # [B, L] bool (True=pad)
- ):
- B, L, C = x.shape
- coords_rc = self._ensure_batched_coords(coords_rc, B)
- assert coords_rc.shape[:2] == (B, L)
-
- if self.ln_in is not None:
- x = self.ln_in(x)
-
- x_dtype = x.dtype
- red_dtype = torch.float32 if self.fp32_reduce else x.dtype
-
- out_x_list = []
- out_coords_list = []
-
- key_stride = int(2**20) # enough for typical H,W<1e6
-
- for b in range(B):
- xb = x[b] # [L, C]
- coords_b = coords_rc[b].to(torch.long) # [L, 2]
- valid = torch.ones(L, dtype=torch.bool, device=x.device) if padmask is None else (~padmask[b])
-
- rb = coords_b[:, 0]
- cb = coords_b[:, 1]
- gr = torch.div(rb, self.merge_size, rounding_mode='floor')
- gc = torch.div(cb, self.merge_size, rounding_mode='floor')
- keys = gr * key_stride + gc
-
- # TL=0 (r%2=0,c%2=0), TR=1, BL=2, BR=3 — fixed ordering
- rmod = torch.remainder(rb, self.merge_size)
- cmod = torch.remainder(cb, self.merge_size)
- corner = (rmod << 1) | cmod
-
- sel = valid
- if sel.sum() == 0:
- out_x_list.append(xb.new_zeros(0, self.out_dim))
- out_coords_list.append(coords_b.new_zeros(0, 2))
- continue
-
- keys = keys[sel]
- corner = corner[sel]
- xb_sel = xb[sel]
-
- uniq, inv = torch.unique(keys, sorted=True, return_inverse=True)
- G = uniq.numel()
-
- # recover (gr,gc) from linear keys
- gc_out = torch.remainder(uniq, key_stride)
- gr_out = torch.div(uniq, key_stride, rounding_mode='floor')
- coords_out = torch.stack([gr_out, gc_out], dim=-1) # [G,2]
-
- if self.mode == 'concat':
- # do FP32 accumulation, then cast back
- out_buf = torch.zeros(G, 4, C, device=x.device, dtype=red_dtype)
- counts = torch.zeros(G, 4, device=x.device, dtype=red_dtype)
-
- for k in range(4):
- mask_k = (corner == k)
- if mask_k.any():
- gi = inv[mask_k] # [Nk]
- xk = xb_sel[mask_k].to(red_dtype) # [Nk, C]
- out_buf[:, k, :].index_add_(0, gi, xk)
- counts[:, k].index_add_(0, gi, torch.ones(gi.shape[0], device=x.device, dtype=red_dtype))
-
- # average duplicates; zeros stay zeros for missing corners
- counts_clamped = counts.clamp_min(1.0).unsqueeze(-1) # [G,4,1]
- out_buf = out_buf / counts_clamped
- out_feat = out_buf.reshape(G, 4*C) # [G, 4C]
- out_feat = self.reduction2(self.act(self.reduction1(out_feat.to(x_dtype))))
-
- elif self.mode == 'sum':
- # variance-preserving sum: scale by 1/sqrt(k)
- out_buf = torch.zeros(G, C, device=x.device, dtype=red_dtype)
- counts = torch.zeros(G, 1, device=x.device, dtype=red_dtype)
- gi = inv
- xk = xb_sel.to(red_dtype)
- out_buf.index_add_(0, gi, xk)
- counts.index_add_(0, gi, torch.ones(gi.shape[0], device=x.device, dtype=red_dtype).unsqueeze(-1))
- scale = counts.clamp_min(1.0).sqrt().reciprocal()
- out_feat = out_buf * scale
- out_feat = self.reduction2(self.act(self.reduction1(out_feat.to(x_dtype))))
- else:
- raise ValueError(f"Unknown mode {self.mode}")
-
- if self.ln_out is not None:
- out_feat = self.ln_out(out_feat)
-
- out_x_list.append(out_feat) # [G, out_dim]
- out_coords_list.append(coords_out) # [G, 2]
-
- # pack (pad) to max group length across batch for dense return
- Gmax = max((t.shape[0] for t in out_x_list), default=0)
- out_x = x.new_zeros(B, Gmax, self.out_dim)
- out_coords = coords_rc.new_zeros(B, Gmax, 2)
- out_mask = torch.ones(B, Gmax, dtype=torch.bool, device=x.device) # True=pad
-
- for b in range(B):
- G = out_x_list[b].shape[0]
- out_x[b, :G] = out_x_list[b]
- out_coords[b, :G] = out_coords_list[b]
- out_mask[b, :G] = False
-
- return out_x, out_coords, out_mask
-
- def forward(
- self,
- x: torch.Tensor, # [B, L, C]
- coords_rc: torch.Tensor, # [B, L, 2] or [L, 2]
- padmask: Optional[torch.Tensor] = None
- ):
- try:
- out = self._forward_impl(x, coords_rc, padmask)
- if self.clear_after_forward and x.is_cuda:
- self._clear_cuda_cache()
- return out
- except RuntimeError as e:
- msg = str(e).lower()
- is_cuda_oom = ("out of memory" in msg) or ("cuda" in msg and "alloc" in msg) or ("cublas" in msg and "alloc" in msg)
- if self.clear_on_oom and is_cuda_oom and x.is_cuda and not self._retrying:
- self._retrying = True
- self._clear_cuda_cache()
- torch.cuda.reset_peak_memory_stats()
- try:
- return self._forward_impl(x, coords_rc, padmask)
- finally:
- self._retrying = False
- raise
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/model/LongNetWithMerging/swin_longnet_encoder.py b/code/xtuner/model/torchscale/model/LongNetWithMerging/swin_longnet_encoder.py
deleted file mode 100644
index b3b24e790cd6dadb4bb17bb5e31ec489e6633f5b..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/LongNetWithMerging/swin_longnet_encoder.py
+++ /dev/null
@@ -1,797 +0,0 @@
-from typing import Optional, List, Dict, Union
-import torch
-import os
-import sys, math
-
-this_file_dir = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(this_file_dir, '../../../'))
-
-from .sparse_patch_merging import SparsePatchMerging
-from .relative_position_bias_2d import RelativePositionBias2D
-from torchscale.architecture.utils import init_bert_params
-from torchscale.component.multiway_network import set_split_position
-from torchscale.architecture.encoder import Encoder, EncoderLayer
-from torchscale.component.dilated_attention import DilatedAttention
-from fairscale.nn import checkpoint_wrapper, wrap
-from mmengine import print_log
-
-
-class LongNetEncoderLayer(EncoderLayer):
- def build_self_attention(self, embed_dim, args):
- return DilatedAttention(
- args,
- embed_dim,
- args.encoder_attention_heads,
- dropout=args.attention_dropout,
- self_attention=True,
- encoder_decoder_attention=False,
- subln=args.subln,
- )
-
-
-class LongNetWithPatchMerging(Encoder):
- def __init__(
- self,
- args,
- embed_tokens=None,
- embed_positions=None,
- output_projection=None,
- is_encoder_decoder=False,
- keep_dim_after_merge: bool = True,
- merge_size: int = 2,
- use_rel_pos_2d: bool = True,
- rel_pos_num_buckets: int = 32,
- rel_pos_max_distance: int = 512,
- # NEW: cache-control knobs (match SparsePatchMerging)
- clear_after_forward: bool = False, # clear CUDA cache after each forward
- clear_on_oom: bool = True, # clear & retry once on CUDA OOM
- **kwargs,
- ):
- self.merge_size = merge_size
- self.keep_dim_after_merge = keep_dim_after_merge
- super().__init__(
- args=args,
- embed_tokens=embed_tokens,
- embed_positions=embed_positions,
- output_projection=output_projection,
- is_encoder_decoder=is_encoder_decoder,
- **kwargs
- )
-
- self.embed_dim = args.encoder_embed_dim
-
-
- # NEW: store cache knobs
- self.clear_after_forward = clear_after_forward
- self.clear_on_oom = clear_on_oom
- self._retrying = False # guard to avoid infinite OOM loops
-
- # build patch merge module
- self.patch_merge = self.build_patch_merge(
- args=args,
- embed_dim=self.embed_dim,
- keep_dim=keep_dim_after_merge,
- merge_size=merge_size,
- )
-
- self.use_rel_pos_2d = use_rel_pos_2d
- self.rel_pos_num_buckets = rel_pos_num_buckets
- self.rel_pos_max_distance = rel_pos_max_distance
-
- if self.use_rel_pos_2d:
- n_heads = args.encoder_attention_heads
- self.relpos2d = RelativePositionBias2D(
- n_heads=n_heads,
- num_buckets=rel_pos_num_buckets,
- max_distance=rel_pos_max_distance,
- bidirectional=True,
- )
- else:
- self.relpos2d = None
-
- if args.bert_init:
- self.apply(init_bert_params)
-
- if args.deepnorm:
- if is_encoder_decoder:
- init_scale = (
- math.pow(
- math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
- )
- / 1.15
- )
- else:
- init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
- for name, p in self.named_parameters():
- if 'reduction' in name or 'reduction1' in name or 'reduction2' in name:
- p.data.mul_(init_scale)
-
- if args.subln:
- if is_encoder_decoder:
- init_scale = math.sqrt(
- math.log(3 * args.decoder_layers)
- * math.log(2 * args.encoder_layers)
- / 3
- )
- else:
- init_scale = math.sqrt(math.log(args.encoder_layers * 2))
- for name, p in self.named_parameters():
- if 'reduction' in name or 'reduction1' in name or 'reduction2' in name:
- p.data.mul_(init_scale)
-
- # NEW: utility to clear CUDA cache
- def _clear_cuda_cache(self):
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- try:
- torch.cuda.ipc_collect()
- except Exception:
- pass
-
- def build_patch_merge(self, args, embed_dim, keep_dim, merge_size):
- merger = SparsePatchMerging(
- args=args,
- embed_dim=embed_dim,
- layernorm_eps=args.layernorm_eps,
- merge_size=merge_size,
- keep_dim=keep_dim,
- subln=args.subln,
- )
- if args.checkpoint_activations:
- merger = checkpoint_wrapper(merger)
- if args.fsdp:
- merger = wrap(merger)
- return merger
-
- def build_encoder_layer(
- self, args, depth, is_moe_layer=False, is_encoder_decoder=False
- ):
-
- layer = LongNetEncoderLayer(
- args,
- depth,
- is_moe_layer=is_moe_layer,
- is_encoder_decoder=is_encoder_decoder,
- )
- if args.checkpoint_activations:
- layer = checkpoint_wrapper(layer)
- if args.fsdp:
- layer = wrap(layer)
- return layer
-
- # NEW: factor original body into _forward_impl so we can OOM-retry & finally-clean
- def _forward_impl(
- self,
- src_tokens,
- encoder_padding_mask=None,
- attn_mask=None,
- return_all_hiddens=False,
- token_embeddings=None,
- multiway_split_position=None,
- features_only=False,
- incremental_state=None,
- positions=None,
- coords: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
- **kwargs,
- ):
- try:
- assert src_tokens is not None or token_embeddings is not None
-
- if encoder_padding_mask is None:
- if src_tokens is not None:
- encoder_padding_mask = torch.zeros_like(
- src_tokens, device=src_tokens.device
- ).bool()
- else:
- encoder_padding_mask = torch.zeros(
- [token_embeddings.size(0), token_embeddings.size(1)],
- device=token_embeddings.device,
- ).bool()
-
- if multiway_split_position is not None:
- assert self.args.multiway
- self.apply(set_split_position(multiway_split_position))
-
- x, encoder_embedding = self.forward_embedding(
- src_tokens, token_embeddings, positions
- )
- x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
-
- B, L, C = x.shape
- # print('input x shape', x.shape)
- encoder_states = []
- if return_all_hiddens:
- encoder_states.append(x)
-
- rel_pos_bias = None
- if self.relative_position is not None:
- rel_pos_bias = self.relative_position(
- batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
- )
-
- if coords is None:
- coords_rc = None
- else:
- # Accept: list[tensor], [L,2] tensor, or [B,L,2] tensor
- coords_t = coords[0] if isinstance(coords, list) else coords
- Bx = x.size(0) # actual batch size of inputs
-
- if not torch.is_tensor(coords_t):
- raise ValueError("coords must be a Tensor or list[Tensor].")
-
- if coords_t.dim() == 2:
- # [L, 2]
- coords_rc = coords_t
- elif coords_t.dim() == 3:
- # [B, L, 2] -> ensure B matches and either B==1 or all examples share coords
- if coords_t.size(0) != Bx:
- raise ValueError(f"coords batch dim mismatch: got {coords_t.size(0)} but inputs have B={Bx}")
- if Bx == 1:
- coords_rc = coords_t[0]
- else:
- # require same coords across the batch (cheap equality check)
- if not torch.equal(coords_t, coords_t[0].unsqueeze(0).expand_as(coords_t)):
- raise NotImplementedError(
- "Per-example coords (varying across batch) are not supported by the current "
- "patch-merging/layout path. Use batch size 1 or share coords across the batch."
- )
- coords_rc = coords_t[0]
- else:
- raise ValueError("coords must have shape [L,2] or [B,L,2].")
-
- if coords_rc.size(-1) != 2:
- raise ValueError("coords last dimension must be 2.")
-
- new_coords = coords_rc
- new_mask = encoder_padding_mask
- num_layers = self.num_layers
- if incremental_state is None:
- incremental_state = [None] * num_layers
-
- l_aux = []
- for idx, layer in enumerate(self.layers):
- rel_pos = None
- if new_coords is not None and self.use_rel_pos_2d:
- rel_pos = self._coords_to_relpos(B, new_coords)
-
- cur_mask = new_mask if incremental_state is None else None
- x, l_aux_i = layer(
- x,
- encoder_padding_mask=cur_mask,
- attn_mask=attn_mask,
- rel_pos=rel_pos,
- multiway_split_position=multiway_split_position,
- )
- # free per-layer rel_pos ASAP
- del rel_pos
-
- if return_all_hiddens:
- encoder_states.append(x)
- l_aux.append(l_aux_i)
-
- # patch merge after activation, except final layer
- if idx < num_layers - 1 and new_coords is not None:
- x, new_coords, new_mask = self.patch_merge(
- x, self._coords_to_rowcol(new_coords), new_mask
- )
- # del _aux # not needed downstream
- # print_log(f'x shape {x.shape}', 'current')
- if self.layer_norm is not None:
- x = self.layer_norm(x)
-
- if not features_only and self.output_projection is not None:
- x = self.output_projection(x)
-
- out = {
- "encoder_out": x,
- "encoder_embedding": encoder_embedding,
- "encoder_padding_mask": encoder_padding_mask,
- "encoder_states": encoder_states,
- "l_aux": l_aux,
- 'coords': new_coords
- }
- return out
- finally:
- # === FINAL CLEANUP ===
- # Explicitly drop coord tensors to help the allocator reclaim earlier.
- if 'coords' in locals():
- del coords
- if 'coords_rc' in locals():
- del coords_rc
- if 'rel_pos_bias' in locals():
- del rel_pos_bias
- # (x/encoder_states are returned, so we can't drop them here.)
-
- def forward(
- self,
- src_tokens,
- encoder_padding_mask=None,
- attn_mask=None,
- return_all_hiddens=False,
- token_embeddings=None,
- multiway_split_position=None,
- features_only=False,
- incremental_state=None,
- positions=None,
- coords: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
- **kwargs,
- ):
- try:
- return self._forward_impl(
- src_tokens,
- encoder_padding_mask=encoder_padding_mask,
- attn_mask=attn_mask,
- return_all_hiddens=return_all_hiddens,
- token_embeddings=token_embeddings,
- multiway_split_position=multiway_split_position,
- features_only=features_only,
- incremental_state=incremental_state,
- positions=positions,
- coords=coords,
- **kwargs,
- )
- except RuntimeError as e:
- msg = str(e).lower()
- is_cuda_oom = ("cuda out of memory" in msg) or ("cublas" in msg and "alloc" in msg)
- if (
- self.clear_on_oom and is_cuda_oom and torch.cuda.is_available()
- and not self._retrying
- ):
- self._retrying = True
- self._clear_cuda_cache()
- try:
- torch.cuda.reset_peak_memory_stats()
- except Exception:
- pass
- try:
- return self._forward_impl(
- src_tokens,
- encoder_padding_mask=encoder_padding_mask,
- attn_mask=attn_mask,
- return_all_hiddens=return_all_hiddens,
- token_embeddings=token_embeddings,
- multiway_split_position=multiway_split_position,
- features_only=features_only,
- incremental_state=incremental_state,
- positions=positions,
- coords=coords,
- **kwargs,
- )
- finally:
- self._retrying = False
- raise
- finally:
- # Optional cache cleanup after a successful (or failed) forward
- if self.clear_after_forward and torch.cuda.is_available():
- self._clear_cuda_cache()
-
- @staticmethod
- def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
- with torch.no_grad():
- x = coords_xy[:, 0]
- y = coords_xy[:, 1]
- x_for_unique = x
- y_for_unique = y
- if x_for_unique.dtype.is_floating_point:
- x_for_unique = x_for_unique.round().to(torch.int)
- y_for_unique = y_for_unique.round().to(torch.int)
- x_sorted = torch.unique(x_for_unique, sorted=True)
- y_sorted = torch.unique(y_for_unique, sorted = True)
-
- col = torch.searchsorted(x_sorted, x)
- row = torch.searchsorted(y_sorted, y)
- return torch.stack([row, col], dim=-1)
-
- def _coords_to_relpos(self, B: int, coords_xy: torch.Tensor):
- rc = self._coords_to_rowcol(coords_xy)
- return self._compute_relpos(B, rc)
-
- def _compute_relpos(self, B: int, coords_rc: torch.Tensor):
- if not self.use_rel_pos_2d:
- return None
- return self.relpos2d(B, coords_rc, coords_rc)
-
-# class LongNetEncoderLayer(EncoderLayer):
-
-# def build_self_attention(self, embed_dim, args):
-# return DilatedAttention(
-# args,
-# embed_dim,
-# args.encoder_attention_heads,
-# dropout=args.attention_dropout,
-# self_attention=True,
-# encoder_decoder_attention=False,
-# subln=args.subln,
-# )
-
-
-# class LongNetWithPatchMerging(Encoder):
-# def __init__(
-# self,
-# args,
-# embed_tokens=None,
-# embed_positions=None,
-# output_projection=None,
-# is_encoder_decoder=False,
-# keep_dim_after_merge: bool = True,
-# merge_size: int = 2,
-# use_rel_pos_2d: bool = True,
-# rel_pos_num_buckets: int = 32,
-# rel_pos_max_distance: int = 512,
-# **kwargs,
-# ):
-# # call the encoder's init
-# super().__init__(
-# args = args,
-# embed_tokens=embed_tokens,
-# embed_positions= embed_positions,
-# output_projection= output_projection,
-# is_encoder_decoder= is_encoder_decoder,
-# **kwargs
-# )
-
-# self.embed_dim = args.encoder_embed_dim
-# # build patch merge module
-# self.merge_size = merge_size
-# self.keep_dim_after_merge = keep_dim_after_merge
-# self.patch_merge = self.build_patch_merge(args = args,
-# embed_dim= self.embed_dim,
-# keep_dim= keep_dim_after_merge,
-# merge_size= merge_size
-# )
-
-
-# self.use_rel_pos_2d = use_rel_pos_2d
-# self.rel_pos_num_buckets = rel_pos_num_buckets
-# self.rel_pos_max_distance = rel_pos_max_distance
-
-# if self.use_rel_pos_2d:
-# n_heads = args.encoder_attention_heads
-# self.relpos2d = RelativePositionBias2D(
-# n_heads=n_heads,
-# num_buckets=rel_pos_num_buckets,
-# max_distance=rel_pos_max_distance,
-# bidirectional=True,
-# )
-# else:
-# self.relpos2d = None
-
-# # do the initialization for patch merger
-# if args.bert_init:
-# self.apply(init_bert_params)
-
-
-# if args.deepnorm:
-# if is_encoder_decoder:
-# init_scale = (
-# math.pow(
-# math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
-# )
-# / 1.15
-# )
-# else:
-# init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
-# for name, p in self.named_parameters():
-# if (
-# 'reduction' in name
-# ):
-# p.data.div_(init_scale)
-
-# if args.subln:
-# if is_encoder_decoder:
-# init_scale = math.sqrt(
-# math.log(3 * args.decoder_layers)
-# * math.log(2 * args.encoder_layers)
-# / 3
-# )
-# else:
-# init_scale = math.sqrt(math.log(args.encoder_layers * 2))
-# for name, p in self.named_parameters():
-# if (
-# 'reduction' in name
-# ):
-# p.data.mul_(init_scale)
-
-
-
-
-# def build_patch_merge(self, args, embed_dim, keep_dim, merge_size):
-# merger = SparsePatchMerging(
-# args = args,
-# embed_dim= embed_dim,
-# layernorm_eps = args.layernorm_eps,
-# merge_size= merge_size,
-# keep_dim = keep_dim,
-# subln= args.subln
-# )
-# if args.checkpoint_activations:
-# merger = checkpoint_wrapper(merger)
-# if args.fsdp:
-# merger = wrap(merger)
-# return merger
-
-# # build longnet encoder layer
-# def build_encoder_layer(
-# self, args, depth, is_moe_layer=False, is_encoder_decoder=False
-# ):
-# layer = LongNetEncoderLayer(
-# args,
-# depth,
-# is_moe_layer=is_moe_layer,
-# is_encoder_decoder=is_encoder_decoder,
-# )
-# if args.checkpoint_activations:
-# layer = checkpoint_wrapper(layer)
-# if args.fsdp:
-# layer = wrap(layer)
-# return layer
-
-# def forward(
-# self,
-# src_tokens,
-# encoder_padding_mask=None,
-# attn_mask=None,
-# return_all_hiddens=False,
-# token_embeddings=None,
-# multiway_split_position=None,
-# features_only=False,
-# incremental_state=None,
-# positions=None,
-
-# # coords from clam framework
-
-# coords: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-# **kwargs
-# ):
-# assert src_tokens is not None or token_embeddings is not None
-
-# if encoder_padding_mask is None:
-# if src_tokens is not None:
-# encoder_padding_mask = torch.zeros_like(
-# src_tokens, device=src_tokens.device
-# ).bool()
-# else:
-# encoder_padding_mask = torch.zeros(
-# [token_embeddings.size(0), token_embeddings.size(1)],
-# device=token_embeddings.device,
-# ).bool()
-
-# if multiway_split_position is not None:
-# assert self.args.multiway
-# self.apply(set_split_position(multiway_split_position))
-
-# x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-# x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
-
-# B, L, C = x.shape
-# encoder_states = []
-
-# if return_all_hiddens:
-# encoder_states.append(x)
-
-# rel_pos_bias = None
-# if self.relative_position is not None:
-# rel_pos_bias = self.relative_position(
-# batch_size=x.size(0), qlen=x.size(1), klen=x.size(1)
-# )
-
-# if coords is None:
-# coords_rc = None
-# else:
-# coords_rc = coords[0] if isinstance(coords, list) else coords
-# if coords_rc.dim() != 2 or coords_rc.size(-1) != 2:
-# raise ValueError("coords must be [L, 2] tensor")
-
-# new_coords = coords_rc
-# new_mask = encoder_padding_mask
-# num_layers = self.num_layers
-# if incremental_state is None:
-# incremental_state = [None] * num_layers
-
-# # incremental_state is not None during inference if we use the bidirectional encoder as a generator as in s2s-ft (https://arxiv.org/abs/2110.13640)
-# l_aux = []
-# for idx, layer in enumerate(self.layers):
-# rel_pos = None
-# if new_coords is not None and self.use_rel_pos_2d:
-# rel_pos = self._coords_to_relpos(B, new_coords)
-
-# x, l_aux_i = layer(
-# x,
-# encoder_padding_mask=encoder_padding_mask if incremental_state is None else None,
-# attn_mask=attn_mask,
-# rel_pos=rel_pos_bias if not self.use_rel_pos_2d else rel_pos,
-# multiway_split_position=multiway_split_position,
-# incremental_state=incremental_state[idx] if incremental_state is not None else None,
-# )
-
-# if return_all_hiddens:
-# assert encoder_states is not None
-# encoder_states.append(x)
-# l_aux.append(l_aux_i)
-
-# # patch merge after activation, do this except the final layer
-# if idx < num_layers - 1 and new_coords is not None:
-# x, new_coords, new_mask, _aux = self.patch_merge(x, self._coords_to_rowcol(new_coords), new_mask)
-
-# if self.layer_norm is not None:
-# x = self.layer_norm(x)
-
-# if not features_only and self.output_projection is not None:
-# x = self.output_projection(x)
-
-
-# return {
-# "encoder_out": x,
-# "encoder_embedding": encoder_embedding,
-# "encoder_padding_mask": encoder_padding_mask,
-# "encoder_states": encoder_states,
-# "l_aux": l_aux,
-# }
-
-# @staticmethod
-# def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
-# with torch.no_grad():
-# x = coords_xy[:, 0]
-# y = coords_xy[:, 1]
-# x_sorted = x.unique(sorted=True)
-# y_sorted = y.unique(sorted=True)
-# col = torch.searchsorted(x_sorted, x)
-# row = torch.searchsorted(y_sorted, y)
-# return torch.stack([row, col], dim=-1)
-
-# def _coords_to_relpos(self, B: int, coords_xy: torch.Tensor):
-# rc = self._coords_to_rowcol(coords_xy)
-# return self._compute_relpos(B, rc)
-
-
-
-
-# class LongNetWithPatchMerging(nn.Module):
-# def __init__(
-# self,
-# backbone_encoder: nn.Module,
-# keep_dim_after_merge: bool = True,
-# merge_size: int = 2,
-# use_rel_pos_2d: bool = True,
-# rel_pos_num_buckets: int = 32,
-# rel_pos_max_distance: int = 512,
-# ):
-# super().__init__()
-# self.backbone = backbone_encoder
-# self.layers = backbone_encoder.layers
-# self.num_layers = getattr(backbone_encoder, "num_layers", len(self.layers))
-# self.embed_dim = getattr(backbone_encoder, "embed_dim", None) or backbone_encoder.args.encoder_embed_dim
-# self.keep_dim_after_merge = keep_dim_after_merge
-# self.merge_size = merge_size
-
-# self.patch_merge = SparsePatchMerging(self.embed_dim, keep_dim=keep_dim_after_merge, merge_size=merge_size)
-# self.merge_proj_back = None
-# if not keep_dim_after_merge:
-# self.merge_proj_back = nn.Linear(self.patch_merge.out_dim, self.embed_dim)
-
-# self.use_rel_pos_2d = use_rel_pos_2d
-# if use_rel_pos_2d:
-# n_heads = backbone_encoder.args.encoder_attention_heads
-# self.relpos2d = RelativePositionBias2D(
-# n_heads=n_heads,
-# num_buckets=rel_pos_num_buckets,
-# max_distance=rel_pos_max_distance,
-# bidirectional=True,
-# )
-# else:
-# self.relpos2d = None
-
-# for name in ["embed_tokens", "embed_positions", "layer_norm", "output_projection", "args"]:
-# if hasattr(backbone_encoder, name):
-# setattr(self, name, getattr(backbone_encoder, name))
-
-# def forward_embedding(self, *args, **kwargs):
-# return self.backbone.forward_embedding(*args, **kwargs)
-
-# def _compute_relpos(self, B: int, coords_rc: torch.Tensor):
-# if not self.use_rel_pos_2d:
-# return None
-# return self.relpos2d(B, coords_rc, coords_rc)
-
-# def forward(
-# self,
-# src_tokens=None,
-# encoder_padding_mask: Optional[torch.Tensor] = None,
-# attn_mask: Optional[torch.Tensor] = None,
-# return_all_hiddens: bool = False,
-# token_embeddings: Optional[torch.Tensor] = None,
-# multiway_split_position=None,
-# features_only: bool = False,
-# incremental_state: Optional[List[Dict]] = None,
-# positions=None,
-# coords: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-# **kwargs,
-# ):
-# assert src_tokens is not None or token_embeddings is not None
-
-# if encoder_padding_mask is None:
-# if token_embeddings is not None:
-# encoder_padding_mask = torch.zeros(
-# [token_embeddings.size(0), token_embeddings.size(1)],
-# device=token_embeddings.device,
-# dtype=torch.bool,
-# )
-# else:
-# encoder_padding_mask = torch.zeros(
-# [src_tokens.size(0), src_tokens.size(1)],
-# device=src_tokens.device,
-# dtype=torch.bool,
-# )
-
-# x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings, positions)
-# x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
-
-# B, L, C = x.shape
-# encoder_states = []
-# l_aux = []
-
-# if return_all_hiddens:
-# encoder_states.append(x)
-
-# if coords is None:
-# coords_rc = None
-# else:
-# coords_rc = coords[0] if isinstance(coords, list) else coords
-# if coords_rc.dim() != 2 or coords_rc.size(-1) != 2:
-# raise ValueError("coords must be [L, 2] tensor")
-
-# new_coords = coords_rc
-# new_mask = encoder_padding_mask
-# num_layers = self.num_layers
-# if incremental_state is None:
-# incremental_state = [None] * num_layers
-
-# for idx, layer in enumerate(self.layers):
-# rel_pos = None
-# if new_coords is not None and self.use_rel_pos_2d:
-# rel_pos = self._coords_to_relpos(B, new_coords)
-
-# x, l_aux_i = layer(
-# x,
-# new_mask,
-# attn_mask=attn_mask,
-# rel_pos=rel_pos,
-# multiway_split_position=multiway_split_position,
-# incremental_state=incremental_state[idx],
-# )
-# if return_all_hiddens:
-# encoder_states.append(x)
-# l_aux.append(l_aux_i)
-
-# if idx < num_layers - 1 and new_coords is not None:
-# x, new_coords, new_mask, _aux = self.patch_merge(x, self._coords_to_rowcol(new_coords), new_mask)
-# if self.merge_proj_back is not None:
-# x = self.merge_proj_back(x)
-
-# if hasattr(self, "layer_norm") and self.layer_norm is not None:
-# x = self.layer_norm(x)
-# if not features_only and hasattr(self, "output_projection") and self.output_projection is not None:
-# x = self.output_projection(x)
-
-# return {
-# "encoder_out": x,
-# "encoder_embedding": encoder_embedding,
-# "encoder_padding_mask": new_mask,
-# "encoder_states": encoder_states,
-# "l_aux": l_aux,
-# "coords": new_coords,
-# }
-
-# @staticmethod
-# def _coords_to_rowcol(coords_xy: torch.Tensor) -> torch.Tensor:
-# x = coords_xy[:, 0]
-# y = coords_xy[:, 1]
-# x_sorted = x.unique(sorted=True)
-# y_sorted = y.unique(sorted=True)
-# col = torch.searchsorted(x_sorted, x)
-# row = torch.searchsorted(y_sorted, y)
-# return torch.stack([row, col], dim=-1)
-
-# def _coords_to_relpos(self, B: int, coords_xy: torch.Tensor):
-# rc = self._coords_to_rowcol(coords_xy)
-# return self._compute_relpos(B, rc)
diff --git a/code/xtuner/model/torchscale/model/__init__.py b/code/xtuner/model/torchscale/model/__init__.py
deleted file mode 100644
index 3ae31e2507e8759f2ac7f85e517288f536c04ac3..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
diff --git a/code/xtuner/model/torchscale/model/__pycache__/LongNet.cpython-311.pyc b/code/xtuner/model/torchscale/model/__pycache__/LongNet.cpython-311.pyc
deleted file mode 100644
index 61c232755962562bc7d4645657baad957b4fc26e..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/__pycache__/LongNet.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/__pycache__/LongNetConfig.cpython-311.pyc b/code/xtuner/model/torchscale/model/__pycache__/LongNetConfig.cpython-311.pyc
deleted file mode 100644
index 2e9a11fb4f030f17f6d5e6f9b789a2e0ad87d2b2..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/__pycache__/LongNetConfig.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/__pycache__/__init__.cpython-311.pyc b/code/xtuner/model/torchscale/model/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 0a7cbc267d3b04fdf5c4718180ccaa28ab4fff80..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/__pycache__/pos_embed.cpython-311.pyc b/code/xtuner/model/torchscale/model/__pycache__/pos_embed.cpython-311.pyc
deleted file mode 100644
index a9f8ef22bfec85f712cdeceb66605a4285826e69..0000000000000000000000000000000000000000
Binary files a/code/xtuner/model/torchscale/model/__pycache__/pos_embed.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/model/torchscale/model/create_longnet_for_training.py b/code/xtuner/model/torchscale/model/create_longnet_for_training.py
deleted file mode 100644
index e7e04bb1a819985792a01bac9eac27cea7286648..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/create_longnet_for_training.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-import huggingface_hub
-from xtuner.registry import BUILDER
-
-try:
- from safetensors.torch import load_file as st_load_file, save_file as st_save_file
- _HAS_ST = True
-except Exception:
- _HAS_ST = False
-
-def is_dist():
- return torch.distributed.is_available() and torch.distributed.is_initialized()
-
-def is_rank0():
- return (not is_dist()) or torch.distributed.get_rank() == 0
-
-
-@BUILDER.register_module()
-def create_longvit_model_fast(
- model_arch: str,
- in_chans: int,
- *,
- hub_id: str | None = None, # e.g. "prov-gigapath/prov-gigapath"
- local_ckpt: str | None = None, # e.g. "/path/slide_encoder.(safetensors|pth)"
- filename: str = "slide_encoder.pth", # HF filename
- local_dir: str | None = None,
- global_pool=False,
- tile_size=224,
- max_wsi_size=420096,
- token_norm = True,
- dtype=None,
-):
- assert (hub_id is None) ^ (local_ckpt is None), \
- "Provide exactly one of {hub_id, local_ckpt}"
-
- from xtuner.model.utils import LoadWoInit
- from xtuner.model.torchscale.model.LongNetVit import create_model_original as _create
- with LoadWoInit():
- model = _create(
- pretrained="", model_arch=model_arch,
- in_chans=in_chans,
- global_pool=global_pool,
- tile_size=tile_size,
- max_wsi_size=max_wsi_size,
- token_norm = token_norm
- )
-
- # --- resolve checkpoint path ---
- if local_ckpt is not None:
- ckpt_path = local_ckpt
- if not os.path.isfile(ckpt_path):
- raise FileNotFoundError(f"[LongNet] local_ckpt not found: {ckpt_path}")
- else:
- # hub path
- if is_rank0():
- huggingface_hub.hf_hub_download(
- hub_id, filename=filename,
- local_dir=local_dir or os.path.join(os.path.expanduser("~"), ".cache"),
- force_download=False, local_files_only=False
- )
- if is_dist():
- torch.distributed.barrier()
- ckpt_path = os.path.join(local_dir or os.path.join(os.path.expanduser("~"), ".cache"), filename)
- if not os.path.isfile(ckpt_path):
- raise FileNotFoundError(f"[LongNet] HF file not present after download: {ckpt_path}")
-
- # --- prefer safetensors if provided ---
- if ckpt_path.endswith(".safetensors"):
- from safetensors.torch import load_file as st_load
- sd = st_load(ckpt_path, device="cpu")
- else:
- try:
- obj = torch.load(ckpt_path, map_location="cpu", weights_only=True)
- except TypeError:
- obj = torch.load(ckpt_path, map_location="cpu")
- sd = obj["model"] if isinstance(obj, dict) and "model" in obj else obj
-
- # --- void-copy first layer; filter mismatches ---
- sd.pop("patch_embed.proj.weight", None)
- sd.pop("patch_embed.proj.bias", None)
- msd = model.state_dict()
- sd = {k: v for k, v in sd.items() if k in msd and msd[k].shape == v.shape}
- model.load_state_dict(sd, strict=False)
-
- # --- reinit first layer + sync dtype/device ---
- m = model.patch_embed.proj
- torch.nn.init.xavier_uniform_(m.weight);
- if m.bias is not None: torch.nn.init.zeros_(m.bias)
- # ref = next(p for p in model.parameters() if p.is_floating_point())
- # m.to(device=ref.device, dtype=ref.dtype)
-
- if dtype is not None:
- model = model.to(dtype=dtype)
-
- # if is_rank0():
- print(f"[LongNet] loaded from {'hub:'+hub_id if hub_id else ckpt_path}")
- return model
\ No newline at end of file
diff --git a/code/xtuner/model/torchscale/model/pos_embed.py b/code/xtuner/model/torchscale/model/pos_embed.py
deleted file mode 100644
index 94ac16be9b86ca05ddac3c8bb20ba06c0d3d2689..0000000000000000000000000000000000000000
--- a/code/xtuner/model/torchscale/model/pos_embed.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
-# DeiT: https://github.com/facebookresearch/deit
-# MAE: https://github.com/facebookresearch/mae
-# --------------------------------------------------------
-#
-# Portions Copyright Prov-GigaPath
-# Original File: https://github.com/facebookresearch/mae
-# --------------------------------------------------------
-# Position embedding utils
-# --------------------------------------------------------
-
-import numpy as np
-
-import torch
-
-
-# --------------------------------------------------------
-# 2D sine-cosine position embedding
-# References:
-# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
-# MoCo v3: https://github.com/facebookresearch/moco-v3
-# --------------------------------------------------------
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
- """
- grid_size: int of the grid height and width
- return:
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- """
- grid_h = np.arange(grid_size, dtype=np.float32)
- grid_w = np.arange(grid_size, dtype=np.float32)
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
- grid = np.stack(grid, axis=0)
-
- grid = grid.reshape([2, 1, grid_size, grid_size])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if cls_token:
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
- return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
- """
- embed_dim: output dimension for each position
- pos: a list of positions to be encoded: size (M,)
- out: (M, D)
- """
- assert embed_dim % 2 == 0
- omega = np.arange(embed_dim // 2, dtype=float)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = np.sin(out) # (M, D/2)
- emb_cos = np.cos(out) # (M, D/2)
-
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
- return emb
-
-
-# --------------------------------------------------------
-# Interpolate position embeddings for high-resolution
-# References:
-# DeiT: https://github.com/facebookresearch/deit
-# --------------------------------------------------------
-def interpolate_pos_embed(model, checkpoint_model):
- if "pos_embed" in checkpoint_model:
- pos_embed_checkpoint = checkpoint_model["pos_embed"]
- embedding_size = pos_embed_checkpoint.shape[-1]
- num_patches = model.patch_embed.num_patches
- num_extra_tokens = model.pos_embed.shape[-2] - num_patches
- # height (== width) for the checkpoint position embedding
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
- # height (== width) for the new position embedding
- new_size = int(num_patches**0.5)
- # class_token and dist_token are kept unchanged
- if orig_size != new_size:
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
- # only the position tokens are interpolated
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
- pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
- checkpoint_model["pos_embed"] = new_pos_embed
diff --git a/code/xtuner/model/transformers_models/__init__.py b/code/xtuner/model/transformers_models/__init__.py
deleted file mode 100644
index 71f7ea1d42e34a3fa6b4239b86a468c2e7727b14..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .deepseek_v2 import (DeepseekTokenizerFast, DeepseekV2Config,
- DeepseekV2ForCausalLM, DeepseekV2Model)
-from .mixtral import MixtralConfig, MixtralForCausalLM, MixtralModel
-
-__all__ = [
- 'DeepseekTokenizerFast', 'DeepseekV2Config', 'DeepseekV2ForCausalLM',
- 'DeepseekV2Model', 'MixtralConfig', 'MixtralForCausalLM', 'MixtralModel'
-]
diff --git a/code/xtuner/model/transformers_models/deepseek_v2/__init__.py b/code/xtuner/model/transformers_models/deepseek_v2/__init__.py
deleted file mode 100644
index 6a74b483ca374f0b50c9e3a5e536e54aa671cca4..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/deepseek_v2/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .configuration_deepseek import DeepseekV2Config
-from .modeling_deepseek import DeepseekV2ForCausalLM, DeepseekV2Model
-from .tokenization_deepseek_fast import DeepseekTokenizerFast
-
-__all__ = [
- 'DeepseekV2ForCausalLM', 'DeepseekV2Model', 'DeepseekV2Config',
- 'DeepseekTokenizerFast'
-]
diff --git a/code/xtuner/model/transformers_models/deepseek_v2/configuration_deepseek.py b/code/xtuner/model/transformers_models/deepseek_v2/configuration_deepseek.py
deleted file mode 100644
index daaddcf4922fcfe3617040da2717ee912a10f123..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/deepseek_v2/configuration_deepseek.py
+++ /dev/null
@@ -1,219 +0,0 @@
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
-
-
-# Compared to the original version, two parameters, `moe_implementation` and
-# `expert_in_one_shard`, have been added.
-class DeepseekV2Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
- defaults will yield a similar configuration to that of the DeepSeek-V2.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 102400):
- Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`DeepseekV2Model`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 11008):
- Dimension of the MLP representations.
- moe_intermediate_size (`int`, *optional*, defaults to 1407):
- Dimension of the MoE representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer decoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer decoder.
- n_shared_experts (`int`, *optional*, defaults to None):
- Number of shared experts, None means dense model.
- n_routed_experts (`int`, *optional*, defaults to None):
- Number of routed experts, None means dense model.
- routed_scaling_factor (`float`, *optional*, defaults to 1.0):
- Scaling factor or routed experts.
- topk_method (`str`, *optional*, defaults to `gready`):
- Topk method used in routed gate.
- n_group (`int`, *optional*, defaults to None):
- Number of groups for routed experts.
- topk_group (`int`, *optional*, defaults to None):
- Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
- num_experts_per_tok (`int`, *optional*, defaults to None):
- Number of selected experts, None means dense model.
- moe_layer_freq (`int`, *optional*, defaults to 1):
- The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
- first_k_dense_replace (`int`, *optional*, defaults to 0):
- Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
- \--k dense layers--/
- norm_topk_prob (`bool`, *optional*, defaults to False):
- Whether to normalize the weights of the routed experts.
- scoring_func (`str`, *optional*, defaults to 'softmax'):
- Method of computing expert weights.
- aux_loss_alpha (`float`, *optional*, defaults to 0.001):
- Auxiliary loss weight coefficient.
- seq_aux = (`bool`, *optional*, defaults to True):
- Whether to compute the auxiliary loss for each individual sample.
- num_key_value_heads (`int`, *optional*):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
- `num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 2048):
- The maximum sequence length that this model might ever be used with.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*):
- Padding token id.
- bos_token_id (`int`, *optional*, defaults to 1):
- Beginning of stream token id.
- eos_token_id (`int`, *optional*, defaults to 2):
- End of stream token id.
- pretraining_tp (`int`, *optional*, defaults to 1):
- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
- document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
- necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
- issue](https://github.com/pytorch/pytorch/issues/76232).
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether to tie weight embeddings
- rope_theta (`float`, *optional*, defaults to 10000.0):
- The base period of the RoPE embeddings.
- rope_scaling (`Dict`, *optional*):
- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
- strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
- `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
- `max_position_embeddings` to the expected new maximum.
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- moe_implementation (`str`, *optional*, defaults to 'origin'):
- The implementation of the moe blocks. 'origin' or 'shard'.
- expert_in_one_shard (`int`, *optional*, defaults to None):
- How many expert models are integrated into a shard. It is used only
- when `moe_implementation` == 'shard'
-
- ```python
- >>> from transformers import DeepseekV2Model, DeepseekV2Config
-
- >>> # Initializing a Deepseek-V2 style configuration
- >>> configuration = DeepseekV2Config()
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = 'deepseek_v2'
- keys_to_ignore_at_inference = ['past_key_values']
-
- def __init__(
- self,
- vocab_size=102400,
- hidden_size=4096,
- intermediate_size=11008,
- moe_intermediate_size=1407,
- num_hidden_layers=30,
- num_attention_heads=32,
- num_key_value_heads=32,
- n_shared_experts=None,
- n_routed_experts=None,
- ep_size=1,
- routed_scaling_factor=1.0,
- kv_lora_rank=512,
- q_lora_rank=1536,
- qk_rope_head_dim=64,
- v_head_dim=128,
- qk_nope_head_dim=128,
- topk_method='gready',
- n_group=None,
- topk_group=None,
- num_experts_per_tok=None,
- moe_layer_freq=1,
- first_k_dense_replace=0,
- norm_topk_prob=False,
- scoring_func='softmax',
- aux_loss_alpha=0.001,
- seq_aux=True,
- hidden_act='silu',
- max_position_embeddings=2048,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- pad_token_id=None,
- bos_token_id=100000,
- eos_token_id=100001,
- pretraining_tp=1,
- tie_word_embeddings=False,
- rope_theta=10000.0,
- rope_scaling=None,
- attention_bias=False,
- attention_dropout=0.0,
- moe_implementation='origin',
- expert_in_one_shard=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.moe_intermediate_size = moe_intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.n_shared_experts = n_shared_experts
- self.n_routed_experts = n_routed_experts
- self.ep_size = ep_size
- self.routed_scaling_factor = routed_scaling_factor
- self.kv_lora_rank = kv_lora_rank
- self.q_lora_rank = q_lora_rank
- self.qk_rope_head_dim = qk_rope_head_dim
- self.v_head_dim = v_head_dim
- self.qk_nope_head_dim = qk_nope_head_dim
- self.topk_method = topk_method
- self.n_group = n_group
- self.topk_group = topk_group
- self.num_experts_per_tok = num_experts_per_tok
- self.moe_layer_freq = moe_layer_freq
- self.first_k_dense_replace = first_k_dense_replace
- self.norm_topk_prob = norm_topk_prob
- self.scoring_func = scoring_func
- self.aux_loss_alpha = aux_loss_alpha
- self.seq_aux = seq_aux
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.pretraining_tp = pretraining_tp
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.moe_implementation = moe_implementation
- self.expert_in_one_shard = expert_in_one_shard
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
diff --git a/code/xtuner/model/transformers_models/deepseek_v2/modeling_deepseek.py b/code/xtuner/model/transformers_models/deepseek_v2/modeling_deepseek.py
deleted file mode 100644
index f58dd466fa7a4b754df2b5e7b3da8911985d182d..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/deepseek_v2/modeling_deepseek.py
+++ /dev/null
@@ -1,2037 +0,0 @@
-# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch DeepSeek model."""
-import copy
-import math
-import os
-import types
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_attn_mask_utils import (
- AttentionMaskConverter, _prepare_4d_attention_mask,
- _prepare_4d_causal_attention_mask,
- _prepare_4d_causal_attention_mask_for_sdpa)
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS,
- is_torch_greater_or_equal_than_1_13)
-from transformers.utils import (add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- is_flash_attn_greater_or_equal_2_10, logging,
- replace_return_docstrings)
-from transformers.utils.import_utils import is_torch_fx_available
-
-from xtuner.utils import load_state_dict_into_model
-from .configuration_deepseek import DeepseekV2Config
-
-if is_flash_attn_2_available():
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import pad_input # noqa
- from flash_attn.bert_padding import index_first_axis, unpad_input
-
-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
-# It means that the function will not be traced through and simply appear as a node in the graph.
-if is_torch_fx_available():
- if not is_torch_greater_or_equal_than_1_13:
- import torch.fx
-
- _prepare_4d_causal_attention_mask = torch.fx.wrap(
- _prepare_4d_causal_attention_mask)
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = 'DeepseekV2Config'
-
-
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-class DeepseekV2RMSNorm(nn.Module):
-
- def __init__(self, hidden_size, eps=1e-6):
- """DeepseekV2RMSNorm is equivalent to T5LayerNorm."""
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance +
- self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
-
-
-class DeepseekV2RotaryEmbedding(nn.Module):
-
- def __init__(self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (
- self.base
- **(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
- self.register_buffer('inv_freq', inv_freq, persistent=False)
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings,
- device=self.inv_freq.device,
- dtype=torch.get_default_dtype(),
- )
- self.max_seq_len_cached = None
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
- freqs = torch.outer(t, self.inv_freq.to(t.device))
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer(
- 'cos_cached', emb.cos().to(dtype), persistent=False)
- self.register_buffer(
- 'sin_cached', emb.sin().to(dtype), persistent=False)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(
- seq_len=seq_len, device=x.device, dtype=x.dtype)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
-class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
- """DeepseekV2RotaryEmbedding extended with linear scaling.
-
- Credits to the Reddit user /u/kaiokendev
- """
-
- def __init__(
- self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- ):
- self.scaling_factor = scaling_factor
- super().__init__(dim, max_position_embeddings, base, device)
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
- t = t / self.scaling_factor
-
- freqs = torch.outer(t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer(
- 'cos_cached', emb.cos().to(dtype), persistent=False)
- self.register_buffer(
- 'sin_cached', emb.sin().to(dtype), persistent=False)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2
-class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
- """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling.
-
- Credits to the Reddit users /u/bloc97 and /u/emozilla
- """
-
- def __init__(
- self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- ):
- self.scaling_factor = scaling_factor
- super().__init__(dim, max_position_embeddings, base, device)
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
-
- if seq_len > self.max_position_embeddings:
- base = self.base * ((self.scaling_factor * seq_len /
- self.max_position_embeddings) -
- (self.scaling_factor - 1))**(
- self.dim / (self.dim - 2))
- inv_freq = 1.0 / (
- base
- **(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
- self.register_buffer('inv_freq', inv_freq, persistent=False)
-
- t = torch.arange(
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
- freqs = torch.outer(t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer(
- 'cos_cached', emb.cos().to(dtype), persistent=False)
- self.register_buffer(
- 'sin_cached', emb.sin().to(dtype), persistent=False)
-
-
-# Inverse dim formula to find dim based on number of rotations
-def yarn_find_correction_dim(num_rotations,
- dim,
- base=10000,
- max_position_embeddings=2048):
- return (dim * math.log(max_position_embeddings /
- (num_rotations * 2 * math.pi))) / (2 *
- math.log(base))
-
-
-# Find dim range bounds based on rotations
-def yarn_find_correction_range(low_rot,
- high_rot,
- dim,
- base=10000,
- max_position_embeddings=2048):
- low = math.floor(
- yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
- high = math.ceil(
- yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
- return max(low, 0), min(high, dim - 1) # Clamp values just in case
-
-
-def yarn_get_mscale(scale=1, mscale=1):
- if scale <= 1:
- return 1.0
- return 0.1 * mscale * math.log(scale) + 1.0
-
-
-def yarn_linear_ramp_mask(min, max, dim):
- if min == max:
- max += 0.001 # Prevent singularity
-
- linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
- ramp_func = torch.clamp(linear_func, 0, 1)
- return ramp_func
-
-
-class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):
-
- def __init__(
- self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- original_max_position_embeddings=4096,
- beta_fast=32,
- beta_slow=1,
- mscale=1,
- mscale_all_dim=0,
- ):
- self.scaling_factor = scaling_factor
- self.original_max_position_embeddings = original_max_position_embeddings
- self.beta_fast = beta_fast
- self.beta_slow = beta_slow
- self.mscale = mscale
- self.mscale_all_dim = mscale_all_dim
- super().__init__(dim, max_position_embeddings, base, device)
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- dim = self.dim
-
- freq_extra = 1.0 / (
- self.base**(torch.arange(
- 0, dim, 2, dtype=torch.float32, device=device) / dim))
- freq_inter = 1.0 / (
- self.scaling_factor * self.base**(torch.arange(
- 0, dim, 2, dtype=torch.float32, device=device) / dim))
-
- low, high = yarn_find_correction_range(
- self.beta_fast,
- self.beta_slow,
- dim,
- self.base,
- self.original_max_position_embeddings,
- )
- inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
- device=device, dtype=torch.float32)
- inv_freq = freq_inter * (1 -
- inv_freq_mask) + freq_extra * inv_freq_mask
- self.register_buffer('inv_freq', inv_freq, persistent=False)
-
- t = torch.arange(seq_len, device=device, dtype=torch.float32)
-
- freqs = torch.outer(t, inv_freq)
-
- _mscale = float(
- yarn_get_mscale(self.scaling_factor, self.mscale) /
- yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
-
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer(
- 'cos_cached', (emb.cos() * _mscale).to(dtype), persistent=False)
- self.register_buffer(
- 'sin_cached', (emb.sin() * _mscale).to(dtype), persistent=False)
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`):
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
- used to pass offsetted position ids when working with a KV-cache.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-
- b, h, s, d = q.shape
- q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
-
- b, h, s, d = k.shape
- k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
-
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-class DeepseekV2MLP(nn.Module):
-
- def __init__(self, config, hidden_size=None, intermediate_size=None):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
- self.intermediate_size = (
- config.intermediate_size
- if intermediate_size is None else intermediate_size)
-
- self.gate_proj = nn.Linear(
- self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(
- self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(
- self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.down_proj(
- self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
-
-
-class MoEGate(nn.Module):
-
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.top_k = config.num_experts_per_tok
- self.n_routed_experts = config.n_routed_experts
- self.routed_scaling_factor = config.routed_scaling_factor
- self.scoring_func = config.scoring_func
- self.alpha = config.aux_loss_alpha
- self.seq_aux = config.seq_aux
- self.topk_method = config.topk_method
- self.n_group = config.n_group
- self.topk_group = config.topk_group
-
- # topk selection algorithm
- self.norm_topk_prob = config.norm_topk_prob
- self.gating_dim = config.hidden_size
- self.weight = nn.Parameter(
- torch.empty((self.n_routed_experts, self.gating_dim)))
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- import torch.nn.init as init
-
- init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-
- def forward(self, hidden_states):
- bsz, seq_len, h = hidden_states.shape
- ### compute gating score
- hidden_states = hidden_states.view(-1, h)
- logits = F.linear(
- hidden_states.type(torch.float32), self.weight.type(torch.float32),
- None)
- if self.scoring_func == 'softmax':
- scores = logits.softmax(dim=-1, dtype=torch.float32)
- else:
- raise NotImplementedError(
- f'insupportable scoring function for MoE gating: {self.scoring_func}'
- )
-
- ### select top-k experts
- # fix official typos
- if self.topk_method in ('gready', 'greedy'):
- topk_weight, topk_idx = torch.topk(
- scores, k=self.top_k, dim=-1, sorted=False)
- elif self.topk_method == 'group_limited_greedy':
- group_scores = (scores.view(bsz * seq_len, self.n_group,
- -1).max(dim=-1).values) # [n, n_group]
- group_idx = torch.topk(
- group_scores, k=self.topk_group, dim=-1,
- sorted=False)[1] # [n, top_k_group]
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
- score_mask = (group_mask.unsqueeze(-1).expand(
- bsz * seq_len, self.n_group,
- self.n_routed_experts // self.n_group).reshape(
- bsz * seq_len, -1)) # [n, e]
- tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
- topk_weight, topk_idx = torch.topk(
- tmp_scores, k=self.top_k, dim=-1, sorted=False)
-
- ### norm gate to sum 1
- if self.top_k > 1 and self.norm_topk_prob:
- denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
- topk_weight = topk_weight / denominator
- else:
- topk_weight = topk_weight * self.routed_scaling_factor
- ### expert-level computation auxiliary loss
- if self.training and self.alpha > 0.0:
- scores_for_aux = scores
- aux_topk = self.top_k
- # always compute aux loss based on the naive greedy topk method
- topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
- if self.seq_aux:
- scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
- ce = torch.zeros(
- bsz, self.n_routed_experts, device=hidden_states.device)
- ce.scatter_add_(
- 1,
- topk_idx_for_aux_loss,
- torch.ones(
- bsz, seq_len * aux_topk, device=hidden_states.device),
- ).div_(seq_len * aux_topk / self.n_routed_experts)
- aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
- dim=1).mean() * self.alpha
- else:
- mask_ce = F.one_hot(
- topk_idx_for_aux_loss.view(-1),
- num_classes=self.n_routed_experts)
- ce = mask_ce.float().mean(0)
- Pi = scores_for_aux.mean(0)
- fi = ce * self.n_routed_experts
- aux_loss = (Pi * fi).sum() * self.alpha
- else:
- aux_loss = None
- return topk_idx, topk_weight, aux_loss
-
-
-class AddAuxiliaryLoss(torch.autograd.Function):
- """The trick function of adding auxiliary (aux) loss, which includes the
- gradient of the aux loss during backpropagation."""
-
- @staticmethod
- def forward(ctx, x, loss):
- assert loss.numel() == 1
- ctx.dtype = loss.dtype
- ctx.required_aux_loss = loss.requires_grad
- return x
-
- @staticmethod
- def backward(ctx, grad_output):
- grad_loss = None
- if ctx.required_aux_loss:
- grad_loss = torch.ones(
- 1, dtype=ctx.dtype, device=grad_output.device)
- return grad_output, grad_loss
-
-
-class ExpertShard(nn.Module):
-
- def __init__(self, config, shard_idx, expert_in_one_shard=10):
- super().__init__()
- hidden_dim = config.hidden_size
- ffn_dim = config.moe_intermediate_size
- self.w1w3 = nn.Parameter(
- torch.empty(expert_in_one_shard, ffn_dim * 2, hidden_dim))
- self.w2 = nn.Parameter(
- torch.empty(expert_in_one_shard, hidden_dim, ffn_dim))
-
- self.act = nn.SiLU()
- self.expert_in_one_shard = expert_in_one_shard
- self.shard_idx = shard_idx
-
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- # Different from nn.Linear module, weights of self.w1w3 and self.w2
- # can not be initialized by DeepseekV2PreTrainedModel._init_weights method
- self.w1w3.data.normal_(0, 0.02)
- self.w2.data.normal_(0, 0.02)
-
- def expert_forward(self, current_state, expert_idx):
- w1w3 = self.w1w3[expert_idx]
- w2 = self.w2[expert_idx]
- gate_up_out = torch.matmul(current_state, w1w3.T)
- gate_out, up_out = gate_up_out.chunk(2, dim=-1)
- gate_out = self.act(gate_out)
- out = gate_out * up_out
- out = torch.matmul(out, w2.T)
- return out
-
- def forward(self, hidden_states, flat_topk_idx, y):
- for i in range(self.expert_in_one_shard):
- expert_idx = i + self.expert_in_one_shard * self.shard_idx
- y[flat_topk_idx == expert_idx] = self.expert_forward(
- hidden_states[flat_topk_idx == expert_idx], i)
- return y
-
-
-class DeepseekV2MoEShard(nn.Module):
- """A mixed expert module containing shared experts."""
-
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.num_experts_per_tok = config.num_experts_per_tok
-
- if hasattr(config, 'ep_size') and config.ep_size > 1:
- raise NotImplementedError
- else:
- self.ep_size = 1
- self.experts_per_rank = config.n_routed_experts
- self.ep_rank = 0
- self.n_routed_experts = config.n_routed_experts
-
- expert_in_one_shard = config.expert_in_one_shard
- assert config.n_routed_experts % expert_in_one_shard == 0, \
- ('n_routed_experts should be divisible by expert_in_one_shard, but got '
- f'n_routed_experts = {config.n_routed_experts} and expert_in_one_shard = {expert_in_one_shard}')
-
- self.shard_num = config.n_routed_experts // expert_in_one_shard
- self.expert_in_one_shard = expert_in_one_shard
- self.experts = nn.ModuleList([
- ExpertShard(config, i, self.expert_in_one_shard)
- for i in range(self.shard_num)
- ])
-
- self.gate = MoEGate(config)
- if config.n_shared_experts is not None:
- intermediate_size = config.moe_intermediate_size * config.n_shared_experts
- self.shared_experts = DeepseekV2MLP(
- config=config, intermediate_size=intermediate_size)
-
- def forward(self, hidden_states):
- if not self.training:
- raise NotImplementedError
-
- identity = hidden_states
- orig_shape = hidden_states.shape
- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- flat_topk_idx = topk_idx.view(-1)
-
- hidden_states = hidden_states.repeat_interleave(
- self.num_experts_per_tok, dim=0)
- y = torch.empty_like(hidden_states)
- y_dtype = y.dtype
- for shard_index in range(self.shard_num):
- y = self.experts[shard_index](hidden_states, flat_topk_idx, y)
- y = ((y.view(*topk_weight.shape, -1) *
- topk_weight.unsqueeze(-1)).sum(dim=1)).type(y_dtype)
- y = y.view(*orig_shape)
- y = AddAuxiliaryLoss.apply(y, aux_loss)
-
- if self.config.n_shared_experts is not None:
- y = y + self.shared_experts(identity)
- return y
-
-
-class DeepseekV2MoE(nn.Module):
- """A mixed expert module containing shared experts."""
-
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.num_experts_per_tok = config.num_experts_per_tok
-
- if hasattr(config, 'ep_size') and config.ep_size > 1:
- assert config.ep_size == dist.get_world_size()
- self.ep_size = config.ep_size
- self.experts_per_rank = config.n_routed_experts // config.ep_size
- self.ep_rank = dist.get_rank()
- self.experts = nn.ModuleList([
- (DeepseekV2MLP(
- config, intermediate_size=config.moe_intermediate_size)
- if i >= self.ep_rank * self.experts_per_rank and i <
- (self.ep_rank + 1) * self.experts_per_rank else None)
- for i in range(config.n_routed_experts)
- ])
- else:
- self.ep_size = 1
- self.experts_per_rank = config.n_routed_experts
- self.ep_rank = 0
- self.experts = nn.ModuleList([
- DeepseekV2MLP(
- config, intermediate_size=config.moe_intermediate_size)
- for i in range(config.n_routed_experts)
- ])
- self.gate = MoEGate(config)
- if config.n_shared_experts is not None:
- intermediate_size = config.moe_intermediate_size * config.n_shared_experts
- self.shared_experts = DeepseekV2MLP(
- config=config, intermediate_size=intermediate_size)
-
- def forward(self, hidden_states):
- identity = hidden_states
- orig_shape = hidden_states.shape
- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- flat_topk_idx = topk_idx.view(-1)
- if self.training:
- hidden_states = hidden_states.repeat_interleave(
- self.num_experts_per_tok, dim=0)
- y = torch.empty_like(hidden_states)
- y_dtype = y.dtype
- for i, expert in enumerate(self.experts):
- y[flat_topk_idx == i] = expert(
- hidden_states[flat_topk_idx == i])
- y = ((y.view(*topk_weight.shape, -1) *
- topk_weight.unsqueeze(-1)).sum(dim=1)).type(y_dtype)
- y = y.view(*orig_shape)
- y = AddAuxiliaryLoss.apply(y, aux_loss)
- else:
- y = self.moe_infer(hidden_states, topk_idx,
- topk_weight).view(*orig_shape)
- if self.config.n_shared_experts is not None:
- y = y + self.shared_experts(identity)
- return y
-
- @torch.no_grad()
- def moe_infer(self, x, topk_ids, topk_weight):
- cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
- cnts.scatter_(1, topk_ids, 1)
- tokens_per_expert = cnts.sum(dim=0)
- idxs = topk_ids.view(-1).argsort()
- sorted_tokens = x[idxs // topk_ids.shape[1]]
- sorted_tokens_shape = sorted_tokens.shape
- if self.ep_size > 1:
- tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
- -1).sum(dim=1)
- tokens_per_expert_group = tokens_per_expert.new_empty(
- tokens_per_expert.shape[0])
- dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
- output_splits = (
- tokens_per_expert_group.view(self.ep_size,
- -1).sum(1).cpu().numpy().tolist())
- gathered_tokens = sorted_tokens.new_empty(
- tokens_per_expert_group.sum(dim=0).cpu().item(),
- sorted_tokens.shape[1])
- input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
- dist.all_to_all(
- list(gathered_tokens.split(output_splits)),
- list(sorted_tokens.split(input_split_sizes)),
- )
- tokens_per_expert_post_gather = tokens_per_expert_group.view(
- self.ep_size, self.experts_per_rank).sum(dim=0)
- gatherd_idxs = np.zeros(
- shape=(gathered_tokens.shape[0], ), dtype=np.int32)
- s = 0
- for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
- gatherd_idxs[s:s + k] = i % self.experts_per_rank
- s += k
- gatherd_idxs = gatherd_idxs.argsort()
- sorted_tokens = gathered_tokens[gatherd_idxs]
- tokens_per_expert = tokens_per_expert_post_gather
- tokens_per_expert = tokens_per_expert.cpu().numpy()
-
- outputs = []
- start_idx = 0
- for i, num_tokens in enumerate(tokens_per_expert):
- end_idx = start_idx + num_tokens
- if num_tokens == 0:
- continue
- expert = self.experts[i + self.ep_rank * self.experts_per_rank]
- tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
- expert_out = expert(tokens_for_this_expert)
- outputs.append(expert_out)
- start_idx = end_idx
-
- outs = torch.cat(
- outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
- if self.ep_size > 1:
- new_x = torch.empty_like(outs)
- new_x[gatherd_idxs] = outs
- gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
- dist.all_to_all(
- list(gathered_tokens.split(input_split_sizes)),
- list(new_x.split(output_splits)),
- )
- outs = gathered_tokens
-
- new_x = torch.empty_like(outs)
- new_x[idxs] = outs
- final_out = (
- new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(
- topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
- return final_out
-
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """This is the equivalent of torch.repeat_interleave(x, dim=1,
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2
-class DeepseekV2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper."""
-
- def __init__(self,
- config: DeepseekV2Config,
- layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f'Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will '
- 'to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` '
- 'when creating this class.')
-
- self.attention_dropout = config.attention_dropout
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
-
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.q_lora_rank = config.q_lora_rank
- self.qk_rope_head_dim = config.qk_rope_head_dim
- self.kv_lora_rank = config.kv_lora_rank
- self.v_head_dim = config.v_head_dim
- self.qk_nope_head_dim = config.qk_nope_head_dim
- self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
-
- self.is_causal = True
-
- if self.q_lora_rank is None:
- self.q_proj = nn.Linear(
- self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
- else:
- self.q_a_proj = nn.Linear(
- self.hidden_size,
- config.q_lora_rank,
- bias=config.attention_bias)
- self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
- self.q_b_proj = nn.Linear(
- config.q_lora_rank,
- self.num_heads * self.q_head_dim,
- bias=False)
-
- self.kv_a_proj_with_mqa = nn.Linear(
- self.hidden_size,
- config.kv_lora_rank + config.qk_rope_head_dim,
- bias=config.attention_bias,
- )
- self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
- self.kv_b_proj = nn.Linear(
- config.kv_lora_rank,
- self.num_heads *
- (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
- bias=False,
- )
-
- self.o_proj = nn.Linear(
- self.num_heads * self.v_head_dim,
- self.hidden_size,
- bias=config.attention_bias,
- )
- self._init_rope()
-
- self.softmax_scale = self.q_head_dim**(-0.5)
- if self.config.rope_scaling is not None:
- mscale_all_dim = self.config.rope_scaling.get('mscale_all_dim', 0)
- scaling_factor = self.config.rope_scaling['factor']
- if mscale_all_dim:
- mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
- self.softmax_scale = self.softmax_scale * mscale * mscale
-
- def _init_rope(self):
- if self.config.rope_scaling is None:
- self.rotary_emb = DeepseekV2RotaryEmbedding(
- self.qk_rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.rope_theta,
- )
- else:
- scaling_type = self.config.rope_scaling['type']
- scaling_factor = self.config.rope_scaling['factor']
- if scaling_type == 'linear':
- self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
- self.qk_rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- )
- elif scaling_type == 'dynamic':
- self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
- self.qk_rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- )
- elif scaling_type == 'yarn':
- kwargs = {
- key: self.config.rope_scaling[key]
- for key in [
- 'original_max_position_embeddings',
- 'beta_fast',
- 'beta_slow',
- 'mscale',
- 'mscale_all_dim',
- ] if key in self.config.rope_scaling
- }
- self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
- self.qk_rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- **kwargs,
- )
- else:
- raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return (tensor.view(bsz, seq_len, self.num_heads,
- self.v_head_dim).transpose(1, 2).contiguous())
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
- )
- bsz, q_len, _ = hidden_states.size()
-
- if self.q_lora_rank is None:
- q = self.q_proj(hidden_states)
- else:
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
- q_nope, q_pe = torch.split(
- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
- compressed_kv, k_pe = torch.split(
- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
- kv = (
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
- bsz, q_len, self.num_heads,
- self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
-
- k_nope, value_states = torch.split(
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- kv_seq_len = value_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(
- kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
- self.q_head_dim)
- query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
- query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
-
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
- self.q_head_dim)
- key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
- key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- attn_weights = (
- torch.matmul(query_states, key_states.transpose(2, 3)) *
- self.softmax_scale)
-
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
- f' {attn_weights.size()}')
- assert attention_mask is not None
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
- )
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(
- attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_weights = nn.functional.dropout(
- attn_weights, p=self.attention_dropout, training=self.training)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
- raise ValueError(
- f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is'
- f' {attn_output.size()}')
-
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.num_heads * self.v_head_dim)
-
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2
-class DeepseekV2FlashAttention2(DeepseekV2Attention):
- """DeepseekV2 flash attention module.
-
- This module inherits from `DeepseekV2Attention` as the weights of the
- module stays untouched. The only required change would be on the forward
- pass where it needs to correctly call the public API of flash attention and
- deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- # DeepseekV2FlashAttention2 attention does not support output_attentions
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
- )
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- if self.q_lora_rank is None:
- q = self.q_proj(hidden_states)
- else:
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
- q_nope, q_pe = torch.split(
- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
- compressed_kv, k_pe = torch.split(
- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
- kv = (
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
- bsz, q_len, self.num_heads,
- self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
-
- k_nope, value_states = torch.split(
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- kv_seq_len = value_states.shape[-2]
-
- kv_seq_len = value_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(
- kv_seq_len, self.layer_idx)
-
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len,
- self.q_head_dim)
- query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
- query_states[:, :, :, self.qk_nope_head_dim:] = q_pe
-
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len,
- self.q_head_dim)
- key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
- key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
-
- if self.q_head_dim != self.v_head_dim:
- value_states = F.pad(value_states,
- [0, self.q_head_dim - self.v_head_dim])
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- dropout_rate = self.attention_dropout if self.training else 0.0
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (DeepseekV2RMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- # Handle the case where the model is quantized
- if hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- elif torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- else:
- target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype
-
- logger.warning_once(
- f'The input hidden states seems to be silently casted in float32, this might be related to'
- f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
- f' {target_dtype}.')
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=dropout_rate,
- softmax_scale=self.softmax_scale,
- )
- if self.q_head_dim != self.v_head_dim:
- attn_output = attn_output[:, :, :, :self.v_head_dim]
-
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
- self.v_head_dim).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
- def _flash_attention_forward(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length,
- dropout=0.0,
- softmax_scale=None,
- ):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
-
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`int`, *optional*):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- """
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__.
- causal = self.is_causal and query_length != 1
-
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- (
- query_states,
- key_states,
- value_states,
- indices_q,
- cu_seq_lens,
- max_seq_lens,
- ) = self._upad_input(query_states, key_states, value_states,
- attention_mask, query_length)
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
- query_length)
- else:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
-
- return attn_output
-
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
- query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
- attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
- key_layer = index_first_axis(
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim),
- indices_k,
- )
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
- head_dim),
- indices_k,
- )
- if query_length == kv_seq_len:
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
- head_dim),
- indices_k,
- )
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
- query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-
-ATTENTION_CLASSES = {
- 'eager': DeepseekV2Attention,
- 'flash_attention_2': DeepseekV2FlashAttention2,
-}
-
-
-class DeepseekV2DecoderLayer(nn.Module):
-
- def __init__(self, config: DeepseekV2Config, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
-
- self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
- config=config, layer_idx=layer_idx)
-
- moe_implementation = config.moe_implementation
- if moe_implementation == 'origin':
- block = DeepseekV2MoE
- elif moe_implementation == 'shard':
- block = DeepseekV2MoEShard
- else:
- raise NotImplementedError
-
- self.mlp = (
- block(config) if
- (config.n_routed_experts is not None
- and layer_idx >= config.first_k_dense_replace and layer_idx %
- config.moe_layer_freq == 0) else DeepseekV2MLP(config))
- self.input_layernorm = DeepseekV2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = DeepseekV2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
- torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- if 'padding_mask' in kwargs:
- warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
- )
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states, )
-
- if output_attentions:
- outputs += (self_attn_weights, )
-
- if use_cache:
- outputs += (present_key_value, )
-
- return outputs
-
-
-def _load_pretrained_model(
- cls,
- model,
- state_dict,
- loaded_keys,
- resolved_archive_file,
- pretrained_model_name_or_path,
- ignore_mismatched_sizes=False,
- sharded_metadata=None,
- _fast_init=True,
- low_cpu_mem_usage=False,
- device_map=None,
- offload_folder=None,
- offload_state_dict=None,
- dtype=None,
- hf_quantizer=None,
- keep_in_fp32_modules=None,
- gguf_path=None,
-):
- if ((state_dict is not None) or (resolved_archive_file is None)
- or (low_cpu_mem_usage) or (device_map is not None)
- or (offload_folder is not None) or
- (not (offload_state_dict is None or offload_state_dict is False))
- or (hf_quantizer is not None) or
- (keep_in_fp32_modules is not None and len(keep_in_fp32_modules) > 0)
- or (gguf_path is not None)):
- raise NotImplementedError
-
- folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
- error_msgs = load_state_dict_into_model(model, folder)
- return model, [], [], [], None, error_msgs
-
-
-DeepseekV2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`DeepseekV2Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- 'The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.',
- DeepseekV2_START_DOCSTRING,
-)
-class DeepseekV2PreTrainedModel(PreTrainedModel):
- config_class = DeepseekV2Config
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
- _no_split_modules = ['DeepseekV2DecoderLayer']
- _skip_keys_device_placement = 'past_key_values'
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
- moe_implementation = kwargs.get('moe_implementation', 'origin')
- if moe_implementation == 'origin':
- return super().from_pretrained(pretrained_model_name_or_path,
- *args, **kwargs)
-
- cls._load_pretrained_model = types.MethodType(_load_pretrained_model,
- cls)
- return super().from_pretrained(pretrained_model_name_or_path, *args,
- **kwargs)
-
-
-DeepseekV2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance;
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
-
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
- 'The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.',
- DeepseekV2_START_DOCSTRING,
-)
-class DeepseekV2Model(DeepseekV2PreTrainedModel):
- """Transformer decoder consisting of *config.num_hidden_layers* layers.
- Each layer is a [`DeepseekV2DecoderLayer`]
-
- Args:
- config: DeepseekV2Config
- """
-
- def __init__(self, config: DeepseekV2Config):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
- self.padding_idx)
- self.layers = nn.ModuleList([
- DeepseekV2DecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
- ])
- self._use_sdpa = config._attn_implementation == 'sdpa'
- self._use_flash_attention_2 = config._attn_implementation == 'flash_attention_2'
- self.norm = DeepseekV2RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = (
- output_attentions if output_attentions is not None else
- self.config.output_attentions)
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = (
- return_dict
- if return_dict is not None else self.config.use_return_dict)
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- 'You cannot specify both input_ids and inputs_embeds at the same time'
- )
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape[:2]
- elif inputs_embeds is not None:
- batch_size, seq_length = inputs_embeds.shape[:2]
- else:
- raise ValueError(
- 'You have to specify either input_ids or inputs_embeds')
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers.'
- )
- use_cache = False
-
- past_key_values_length = 0
- if use_cache:
- use_legacy_cache = not isinstance(past_key_values, Cache)
- if use_legacy_cache:
- past_key_values = DynamicCache.from_legacy_cache(
- past_key_values)
- past_key_values_length = past_key_values.get_usable_length(
- seq_length)
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length,
- seq_length + past_key_values_length,
- dtype=torch.long,
- device=device,
- )
- position_ids = position_ids.unsqueeze(0)
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = (
- attention_mask if
- (attention_mask is not None and 0 in attention_mask) else None)
- elif self._use_sdpa and not output_attentions:
- # output_attentions=True can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- )
- else:
- # 4d mask is passed through the layers
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- )
-
- # embed positions
- hidden_states = inputs_embeds
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- attention_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[
- 2 if output_attentions else 1]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1], )
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- next_cache = None
- if use_cache:
- next_cache = (
- next_decoder_cache.to_legacy_cache()
- if use_legacy_cache else next_decoder_cache)
- if not return_dict:
- return tuple(
- v for v in
- [hidden_states, next_cache, all_hidden_states, all_self_attns]
- if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
-
-class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
- _tied_weights_keys = ['lm_head.weight']
-
- def __init__(self, config):
- super().__init__(config)
- self.model = DeepseekV2Model(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(
- config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
- @replace_return_docstrings(
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
-
- >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
- output_attentions = (
- output_attentions if output_attentions is not None else
- self.config.output_attentions)
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- return_dict = (
- return_dict
- if return_dict is not None else self.config.use_return_dict)
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- logits = self.lm_head(hidden_states)
- logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits, ) + outputs[1:]
- return (loss, ) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- **kwargs,
- ):
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- cache_length = past_key_values.get_seq_length()
- past_length = past_key_values.seen_tokens
- max_cache_length = past_key_values.get_max_length()
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
- max_cache_length = None
-
- # Keep only the unprocessed tokens:
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
- # input)
- if (attention_mask is not None
- and attention_mask.shape[1] > input_ids.shape[1]):
- input_ids = input_ids[:, -(attention_mask.shape[1] -
- past_length):]
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
- # input_ids based on the past_length.
- elif past_length < input_ids.shape[1]:
- input_ids = input_ids[:, past_length:]
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
- if (max_cache_length is not None and attention_mask is not None
- and cache_length + input_ids.shape[1] > max_cache_length):
- attention_mask = attention_mask[:, -max_cache_length:]
-
- position_ids = kwargs.get('position_ids', None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1]:]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {'inputs_embeds': inputs_embeds}
- else:
- model_inputs = {'input_ids': input_ids}
-
- model_inputs.update({
- 'position_ids': position_ids,
- 'past_key_values': past_key_values,
- 'use_cache': kwargs.get('use_cache'),
- 'attention_mask': attention_mask,
- })
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (tuple(
- past_state.index_select(0, beam_idx.to(past_state.device))
- for past_state in layer_past), )
- return reordered_past
-
-
-@add_start_docstrings(
- """
- The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).
-
- [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
- (e.g. GPT-2) do.
-
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- DeepseekV2_START_DOCSTRING,
-)
-class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = DeepseekV2Model(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = (
- return_dict
- if return_dict is not None else self.config.use_return_dict)
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError(
- 'Cannot handle batch sizes > 1 if no padding token is defined.'
- )
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- sequence_lengths = (torch.eq(
- input_ids, self.config.pad_token_id).int().argmax(-1) -
- 1).to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device),
- sequence_lengths]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = 'regression'
- elif self.num_labels > 1 and (labels.dtype == torch.long
- or labels.dtype == torch.int):
- self.config.problem_type = 'single_label_classification'
- else:
- self.config.problem_type = 'multi_label_classification'
-
- if self.config.problem_type == 'regression':
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == 'single_label_classification':
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(
- pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == 'multi_label_classification':
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits, ) + transformer_outputs[1:]
- return ((loss, ) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
diff --git a/code/xtuner/model/transformers_models/deepseek_v2/tokenization_deepseek_fast.py b/code/xtuner/model/transformers_models/deepseek_v2/tokenization_deepseek_fast.py
deleted file mode 100644
index 89e3cbb50b61c357deeb3fd37b9eab1188018172..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/deepseek_v2/tokenization_deepseek_fast.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from typing import List, Optional, Union
-
-from transformers.models.llama import LlamaTokenizerFast
-
-
-class DeepseekTokenizerFast(LlamaTokenizerFast):
-
- def convert_ids_to_tokens(
- self,
- ids: Union[int, List[int]],
- skip_special_tokens: bool = False) -> Union[str, List[str]]:
- """Converts a single index or a sequence of indices in a token or a
- sequence of tokens, using the vocabulary and added tokens.
-
- Args:
- ids (`int` or `List[int]`):
- The token id (or token ids) to convert to tokens.
- skip_special_tokens (`bool`, *optional*, defaults to `False`):
- Whether or not to remove special tokens in the decoding.
-
- Returns:
- `str` or `List[str]`: The decoded token(s).
- """
- if isinstance(ids, int):
- return self._convert_id_to_token(ids)
- tokens = []
- for index in ids:
- index = int(index)
- if skip_special_tokens and index in self.all_special_ids:
- continue
- token = self._tokenizer.id_to_token(index)
- tokens.append(token if token is not None else '')
- return tokens
-
- def _convert_id_to_token(self, index: int) -> Optional[str]:
- token = self._tokenizer.id_to_token(int(index))
- return token if token is not None else ''
diff --git a/code/xtuner/model/transformers_models/mixtral/__init__.py b/code/xtuner/model/transformers_models/mixtral/__init__.py
deleted file mode 100644
index aabfd89dbbd8cb1b7f3233ecf6f2bd384aaddd03..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/mixtral/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .configuration_mixtral import MixtralConfig
-from .modeling_mixtral import MixtralForCausalLM, MixtralModel
-
-__all__ = ['MixtralForCausalLM', 'MixtralModel', 'MixtralConfig']
diff --git a/code/xtuner/model/transformers_models/mixtral/configuration_mixtral.py b/code/xtuner/model/transformers_models/mixtral/configuration_mixtral.py
deleted file mode 100644
index 457aefd479f4cae837e63b3af66c25de52d5ac96..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/mixtral/configuration_mixtral.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Mixtral model configuration."""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class MixtralConfig(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
- Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
- with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1.
-
- [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B)
- [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1)
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 32000):
- Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`MixtralModel`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 14336):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*, defaults to 8):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
- The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
- allows sequence of up to 4096*32 tokens.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- pad_token_id (`int`, *optional*):
- The id of the padding token.
- bos_token_id (`int`, *optional*, defaults to 1):
- The id of the "beginning-of-sequence" token.
- eos_token_id (`int`, *optional*, defaults to 2):
- The id of the "end-of-sequence" token.
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
- Whether the model's input and output word embeddings should be tied.
- rope_theta (`float`, *optional*, defaults to 1000000.0):
- The base period of the RoPE embeddings.
- sliding_window (`int`, *optional*):
- Sliding window attention window size. If not specified, will default to `4096`.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- The dropout ratio for the attention probabilities.
- num_experts_per_tok (`int`, *optional*, defaults to 2):
- The number of experts to root per-token, can be also interpreted as the `top-p` routing
- parameter
- num_local_experts (`int`, *optional*, defaults to 8):
- Number of experts per Sparse MLP layer.
- output_router_logits (`bool`, *optional*, defaults to `False`):
- Whether or not the router logits should be returned by the model. Enabling this will also
- allow the model to output the auxiliary loss. See [here]() for more details
- router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
- The aux loss factor for the total loss.
- router_jitter_noise (`float`, *optional*, defaults to 0.0):
- Amount of noise to add to the router.
- moe_implementation (`str`, *optional*, defaults to 'origin'):
- The implementation of the moe blocks. 'origin' or 'shard'.
- expert_in_one_shard (`int`, *optional*, defaults to None):
- How many expert models are integrated into a shard. It is used only
- when `moe_implementation` == 'shard'.
-
- ```python
- >>> from transformers import MixtralModel, MixtralConfig
-
- >>> # Initializing a Mixtral 7B style configuration
- >>> configuration = MixtralConfig()
-
- >>> # Initializing a model from the Mixtral 7B style configuration
- >>> model = MixtralModel(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = 'mixtral'
- keys_to_ignore_at_inference = ['past_key_values']
-
- def __init__(
- self,
- vocab_size=32000,
- hidden_size=4096,
- intermediate_size=14336,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=8,
- hidden_act='silu',
- max_position_embeddings=4096 * 32,
- initializer_range=0.02,
- rms_norm_eps=1e-5,
- use_cache=True,
- pad_token_id=None,
- bos_token_id=1,
- eos_token_id=2,
- tie_word_embeddings=False,
- rope_theta=1e6,
- sliding_window=None,
- attention_dropout=0.0,
- num_experts_per_tok=2,
- num_local_experts=8,
- output_router_logits=False,
- router_aux_loss_coef=0.001,
- router_jitter_noise=0.0,
- moe_implementation='origin',
- expert_in_one_shard=None,
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.sliding_window = sliding_window
-
- # for backward compatibility
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
-
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.attention_dropout = attention_dropout
-
- self.num_experts_per_tok = num_experts_per_tok
- self.num_local_experts = num_local_experts
- self.output_router_logits = output_router_logits
- self.router_aux_loss_coef = router_aux_loss_coef
- self.router_jitter_noise = router_jitter_noise
-
- self.moe_implementation = moe_implementation
- self.expert_in_one_shard = expert_in_one_shard
-
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
diff --git a/code/xtuner/model/transformers_models/mixtral/modeling_mixtral.py b/code/xtuner/model/transformers_models/mixtral/modeling_mixtral.py
deleted file mode 100644
index 94d048fe723cb2179a696fdeb4f698fb3fd870b3..0000000000000000000000000000000000000000
--- a/code/xtuner/model/transformers_models/mixtral/modeling_mixtral.py
+++ /dev/null
@@ -1,1821 +0,0 @@
-# Modified from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/mixtral/modeling_mixtral.py
-"""PyTorch Mixtral model."""
-import inspect
-import math
-import os
-import types
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
- _prepare_4d_causal_attention_mask,
- _prepare_4d_causal_attention_mask_for_sdpa)
-from transformers.modeling_outputs import (MoeCausalLMOutputWithPast,
- MoeModelOutputWithPast,
- SequenceClassifierOutputWithPast)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
-from transformers.utils import (add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_flash_attn_2_available,
- is_flash_attn_greater_or_equal_2_10, logging,
- replace_return_docstrings)
-from transformers.utils.import_utils import is_torch_fx_available
-
-from xtuner.utils import load_state_dict_into_model
-from .configuration_mixtral import MixtralConfig
-
-if is_flash_attn_2_available():
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- from flash_attn.bert_padding import pad_input # noqa
- from flash_attn.bert_padding import index_first_axis, unpad_input
-
- _flash_supports_window_size = 'window_size' in list(
- inspect.signature(flash_attn_func).parameters)
-
-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
-# It means that the function will not be traced through and simply appear as a node in the graph.
-if is_torch_fx_available():
- if not is_torch_greater_or_equal_than_1_13:
- import torch.fx
-
- _prepare_4d_causal_attention_mask = torch.fx.wrap(
- _prepare_4d_causal_attention_mask)
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = 'MixtralConfig'
-
-
-def load_balancing_loss_func(
- gate_logits: torch.Tensor,
- num_experts: torch.Tensor = None,
- top_k=2,
- attention_mask: Optional[torch.Tensor] = None) -> float:
- r"""
- Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
-
- See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
- function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
- experts is too unbalanced.
-
- Args:
- gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
- Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
- shape [batch_size X sequence_length, num_experts].
- attention_mask (`torch.Tensor`, None):
- The attention_mask used in forward function
- shape [batch_size X sequence_length] if not None.
- num_experts (`int`, *optional*):
- Number of experts
-
- Returns:
- The auxiliary loss.
- """
- if gate_logits is None or not isinstance(gate_logits, tuple):
- return 0
-
- if isinstance(gate_logits, tuple):
- compute_device = gate_logits[0].device
- concatenated_gate_logits = torch.cat(
- [layer_gate.to(compute_device) for layer_gate in gate_logits],
- dim=0)
-
- routing_weights = torch.nn.functional.softmax(
- concatenated_gate_logits, dim=-1)
-
- _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-
- expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
-
- if attention_mask is None:
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
-
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.mean(routing_weights, dim=0)
- else:
- batch_size, sequence_length = attention_mask.shape
- num_hidden_layers = concatenated_gate_logits.shape[0] // (
- batch_size * sequence_length)
-
- # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
- expert_attention_mask = (
- attention_mask[None, :, :, None, None].expand(
- (num_hidden_layers, batch_size, sequence_length, top_k,
- num_experts)).reshape(-1, top_k,
- num_experts).to(compute_device))
-
- # Compute the percentage of tokens routed to each experts
- tokens_per_expert = torch.sum(
- expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
- expert_attention_mask, dim=0)
-
- # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
- router_per_expert_attention_mask = (
- attention_mask[None, :, :, None].expand(
- (num_hidden_layers, batch_size, sequence_length,
- num_experts)).reshape(-1, num_experts).to(compute_device))
-
- # Compute the average probability of routing to these experts
- router_prob_per_expert = torch.sum(
- routing_weights * router_per_expert_attention_mask,
- dim=0) / torch.sum(
- router_per_expert_attention_mask, dim=0)
-
- overall_loss = torch.sum(tokens_per_expert *
- router_prob_per_expert.unsqueeze(0))
- return overall_loss * num_experts
-
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
-class MixtralRMSNorm(nn.Module):
-
- def __init__(self, hidden_size, eps=1e-6):
- """MixtralRMSNorm is equivalent to T5LayerNorm."""
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance +
- self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
-class MixtralRotaryEmbedding(nn.Module):
-
- def __init__(self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (
- self.base
- **(torch.arange(0, self.dim, 2,
- dtype=torch.int64).float().to(device) / self.dim))
- self.register_buffer('inv_freq', inv_freq, persistent=False)
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings,
- device=self.inv_freq.device,
- dtype=torch.get_default_dtype())
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(
- self.max_seq_len_cached, device=device,
- dtype=torch.int64).type_as(self.inv_freq)
-
- freqs = torch.outer(t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer(
- 'cos_cached', emb.cos().to(dtype), persistent=False)
- self.register_buffer(
- 'sin_cached', emb.sin().to(dtype), persistent=False)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(
- seq_len=seq_len, device=x.device, dtype=x.dtype)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`):
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
- used to pass offsetted position ids when working with a KV-cache.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-# Copied from transformers.models.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """This is the equivalent of torch.repeat_interleave(x, dim=1,
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
-class MixtralAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper.
-
- Modified to use sliding window attention: Longformer and "Generating Long
- Sequences with Sparse Transformers".
- """
-
- def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
- 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
- 'when creating this class.')
-
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- self.is_causal = True
- self.attention_dropout = config.attention_dropout
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
- f' and `num_heads`: {self.num_heads}).')
- self.q_proj = nn.Linear(
- self.hidden_size, self.num_heads * self.head_dim, bias=False)
- self.k_proj = nn.Linear(
- self.hidden_size,
- self.num_key_value_heads * self.head_dim,
- bias=False)
- self.v_proj = nn.Linear(
- self.hidden_size,
- self.num_key_value_heads * self.head_dim,
- bias=False)
- self.o_proj = nn.Linear(
- self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
- self.rotary_emb = MixtralRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.rope_theta,
- )
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads,
- self.head_dim).transpose(1, 2).contiguous()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(
- kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_weights = torch.matmul(query_states, key_states.transpose(
- 2, 3)) / math.sqrt(self.head_dim)
-
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
- f' {attn_weights.size()}')
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
- )
-
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(
- attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_weights = nn.functional.dropout(
- attn_weights, p=self.attention_dropout, training=self.training)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
- f' {attn_output.size()}')
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
-class MixtralFlashAttention2(MixtralAttention):
- """Mixtral flash attention module.
-
- This module inherits from `MixtralAttention` as the weights of the module
- stays untouched. The only required change would be on the forward pass
- where it needs to correctly call the public API of flash attention and deal
- with padding tokens in case the input contains any of them.
- """
-
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10(
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- ):
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
- 'with a layer index.')
- kv_seq_len += past_key_value.get_usable_length(
- kv_seq_len, self.layer_idx)
-
- # Because the input can be padded, the absolute sequence length depends on the max position id.
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- use_sliding_windows = (
- _flash_supports_window_size
- and getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window)
-
- if not _flash_supports_window_size:
- logger.warning_once(
- 'The current flash attention version does not support sliding window attention, for a more memory efficient implementation'
- ' make sure to upgrade flash-attn library.')
-
- if past_key_value is not None:
- # Activate slicing cache only if the config has a value `sliding_windows` attribute
- cache_has_contents = past_key_value.get_seq_length(
- self.layer_idx) > 0
- if (getattr(self.config, 'sliding_window', None) is not None
- and kv_seq_len > self.config.sliding_window
- and cache_has_contents):
- slicing_tokens = 1 - self.config.sliding_window
-
- past_key = past_key_value[self.layer_idx][0]
- past_value = past_key_value[self.layer_idx][1]
-
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
- if past_key.shape[-2] != self.config.sliding_window - 1:
- raise ValueError(
- f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
- f' {past_key.shape}')
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, slicing_tokens:]
- attention_mask = torch.cat([
- attention_mask,
- torch.ones_like(attention_mask[:, -1:])
- ],
- dim=-1)
-
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- dropout_rate = 0.0 if not self.training else self.attention_dropout
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in float16 just to be sure everything works as expected.
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, '_pre_quantization_dtype'):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f'The input hidden states seems to be silently casted in float32, this might be related to'
- f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
- f' {target_dtype}.')
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- # Reashape to the expected shape for Flash Attention
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- attn_output = self._flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=dropout_rate,
- use_sliding_windows=use_sliding_windows,
- )
-
- attn_output = attn_output.reshape(bsz, q_len,
- self.hidden_size).contiguous()
- attn_output = self.o_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
- def _flash_attention_forward(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- query_length,
- dropout=0.0,
- softmax_scale=None,
- use_sliding_windows=False,
- ):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
-
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`float`):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- use_sliding_windows (`bool`, *optional*):
- Whether to activate sliding window attention.
- """
- if not self._flash_attn_uses_top_left_mask:
- causal = self.is_causal
- else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
- causal = self.is_causal and query_length != 1
-
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, attention_mask,
- query_length)
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- if not use_sliding_windows:
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
- else:
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=(self.config.sliding_window,
- self.config.sliding_window),
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
- query_length)
- else:
- if not use_sliding_windows:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
- else:
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states,
- dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- window_size=(self.config.sliding_window,
- self.config.sliding_window),
- )
-
- return attn_output
-
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
- query_length):
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-
- # On the first iteration we need to properly re-create the padding mask
- # by slicing it on the proper place
- if kv_seq_len != attention_mask.shape[-1]:
- attention_mask_num_tokens = attention_mask.shape[-1]
- attention_mask = attention_mask[:, attention_mask_num_tokens -
- kv_seq_len:]
-
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
- attention_mask)
-
- key_layer = index_first_axis(
- key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
- indices_k)
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
- indices_k)
-
- if query_length == kv_seq_len:
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, num_heads,
- head_dim), indices_k)
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
- query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-
-# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
-class MixtralSdpaAttention(MixtralAttention):
- """Mixtral attention module using
- torch.nn.functional.scaled_dot_product_attention.
-
- This module inherits from `MixtralAttention` as the weights of the module
- stays untouched. The only changes are on the forward pass to adapt to SDPA
- API.
- """
-
- # Adapted from MixtralAttention.forward
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- 'MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads,
- self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(
- kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx, cache_kwargs)
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
- )
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == 'cuda' and attention_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- dropout_p=self.attention_dropout if self.training else 0.0,
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-
- attn_output = self.o_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-MIXTRAL_ATTENTION_CLASSES = {
- 'eager': MixtralAttention,
- 'flash_attention_2': MixtralFlashAttention2,
- 'sdpa': MixtralSdpaAttention,
-}
-
-
-class MixtralBlockSparseTop2MLP(nn.Module):
-
- def __init__(self, config: MixtralConfig):
- super().__init__()
- self.ffn_dim = config.intermediate_size
- self.hidden_dim = config.hidden_size
-
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
-
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, hidden_states):
- current_hidden_states = self.act_fn(
- self.w1(hidden_states)) * self.w3(hidden_states)
- current_hidden_states = self.w2(current_hidden_states)
- return current_hidden_states
-
-
-class MixtralSparseMoeBlock(nn.Module):
- """This implementation is strictly equivalent to standard MoE with full
- capacity (no dropped tokens).
-
- It's faster since it formulates MoE operations in terms of block-sparse
- operations to accommodate imbalanced assignments of tokens to experts,
- whereas standard MoE either (1) drop tokens at the cost of reduced
- performance or (2) set capacity factor to number of experts and thus waste
- computation and memory on padding.
- """
-
- def __init__(self, config):
- super().__init__()
- self.hidden_dim = config.hidden_size
- self.ffn_dim = config.intermediate_size
- self.num_experts = config.num_local_experts
- self.top_k = config.num_experts_per_tok
-
- # gating
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-
- self.experts = nn.ModuleList([
- MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)
- ])
-
- # Jitter parameters
- self.jitter_noise = config.router_jitter_noise
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """"""
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- if self.training and self.jitter_noise > 0:
- hidden_states *= torch.empty_like(hidden_states).uniform_(
- 1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
- hidden_states = hidden_states.view(-1, hidden_dim)
- # router_logits: (batch * sequence_length, n_experts)
- router_logits = self.gate(hidden_states)
-
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- routing_weights, selected_experts = torch.topk(
- routing_weights, self.top_k, dim=-1)
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(hidden_states.dtype)
-
- final_hidden_states = torch.zeros(
- (batch_size * sequence_length, hidden_dim),
- dtype=hidden_states.dtype,
- device=hidden_states.device)
-
- # One hot encode the selected experts to create an expert mask
- # this will be used to easily index which expert is going to be sollicitated
- expert_mask = torch.nn.functional.one_hot(
- selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
- # Loop over all available experts in the model and perform the computation on each expert
- for expert_idx in range(self.num_experts):
- expert_layer = self.experts[expert_idx]
- idx, top_x = torch.where(expert_mask[expert_idx])
-
- # Index the correct hidden states and compute the expert hidden state for
- # the current expert. We need to make sure to multiply the output hidden
- # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
- current_hidden_states = expert_layer(
- current_state) * routing_weights[top_x, idx, None]
-
- # However `index_add_` only support torch tensors for indexing so we'll use
- # the `top_x` tensor here.
- final_hidden_states.index_add_(
- 0, top_x, current_hidden_states.to(hidden_states.dtype))
- final_hidden_states = final_hidden_states.reshape(
- batch_size, sequence_length, hidden_dim)
- return final_hidden_states, router_logits
-
-
-class ExpertShard(nn.Module):
-
- def __init__(self, config, expert_in_one_shard=1):
- super().__init__()
- self.w1w3 = nn.Parameter(
- torch.empty(expert_in_one_shard, config.intermediate_size * 2,
- config.hidden_size))
- self.w2 = nn.Parameter(
- torch.empty(expert_in_one_shard, config.hidden_size,
- config.intermediate_size))
- self.act = ACT2FN[config.hidden_act]
- self.expert_in_one_shard = expert_in_one_shard
-
- def forward(self, hidden_states, expert_mask, routing_weights,
- final_hidden_states):
- hidden_dim = hidden_states.shape[-1]
- for expert_idx in range(self.expert_in_one_shard):
- idx, top_x = torch.where(expert_mask[expert_idx])
- current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-
- w1w3 = self.w1w3[expert_idx]
- w2 = self.w2[expert_idx]
- gate_up_out = torch.matmul(current_state, w1w3.T)
- gate_out, up_out = gate_up_out.chunk(2, dim=-1)
- gate_out = self.act(gate_out)
- out = gate_out * up_out
- out = torch.matmul(out, w2.T)
-
- current_hidden_states = out * routing_weights[top_x, idx, None]
- final_hidden_states.index_add_(
- 0, top_x, current_hidden_states.to(hidden_states.dtype))
- return final_hidden_states
-
-
-class MixtralSparseShardMoeBlock(nn.Module):
-
- def __init__(self, config):
- super().__init__()
- self.hidden_dim = config.hidden_size
- self.ffn_dim = config.intermediate_size
- self.num_experts = config.num_local_experts
- self.top_k = config.num_experts_per_tok
-
- # gating
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-
- expert_in_one_shard = config.expert_in_one_shard
- assert config.num_local_experts % expert_in_one_shard == 0, \
- ('num_local_experts should be divisible by expert_in_one_shard, but got '
- f'num_local_experts = {config.num_local_experts} and expert_in_one_shard = {expert_in_one_shard}')
- self.shard_num = config.num_local_experts // expert_in_one_shard
- self.expert_in_one_shard = expert_in_one_shard
- self.experts = nn.ModuleList([
- ExpertShard(config, self.expert_in_one_shard)
- for i in range(self.shard_num)
- ])
-
- # Jitter parameters
- self.jitter_noise = config.router_jitter_noise
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- """"""
- batch_size, sequence_length, hidden_dim = hidden_states.shape
- if self.training and self.jitter_noise > 0:
- hidden_states *= torch.empty_like(hidden_states).uniform_(
- 1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
- hidden_states = hidden_states.view(-1, hidden_dim)
- # router_logits: (batch * sequence_length, n_experts)
- router_logits = self.gate(hidden_states)
-
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
- routing_weights, selected_experts = torch.topk(
- routing_weights, self.top_k, dim=-1)
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
- # we cast back to the input dtype
- routing_weights = routing_weights.to(hidden_states.dtype)
-
- final_hidden_states = torch.zeros(
- (batch_size * sequence_length, hidden_dim),
- dtype=hidden_states.dtype,
- device=hidden_states.device)
-
- # One hot encode the selected experts to create an expert mask
- # this will be used to easily index which expert is going to be sollicitated
- expert_mask = torch.nn.functional.one_hot(
- selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
- # Loop over all available experts in the model and perform the computation on each expert
- for shard_index in range(self.shard_num):
- mask = expert_mask[shard_index *
- self.expert_in_one_shard:(shard_index + 1) *
- self.expert_in_one_shard]
- final_hidden_states = self.experts[shard_index](
- hidden_states, mask, routing_weights, final_hidden_states)
-
- final_hidden_states = final_hidden_states.reshape(
- batch_size, sequence_length, hidden_dim)
- return final_hidden_states, router_logits
-
-
-class MixtralDecoderLayer(nn.Module):
-
- def __init__(self, config: MixtralConfig, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
-
- self.self_attn = MIXTRAL_ATTENTION_CLASSES[
- config._attn_implementation](config, layer_idx)
-
- moe_implementation = config.moe_implementation
- if moe_implementation == 'origin':
- block = MixtralSparseMoeBlock
- elif moe_implementation == 'shard':
- block = MixtralSparseShardMoeBlock
- else:
- raise NotImplementedError
- self.block_sparse_moe = block(config)
-
- self.input_layernorm = MixtralRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = MixtralRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- output_router_logits: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
- torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_router_logits (`bool`, *optional*):
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
- should not be returned during inference.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- """
-
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states, )
-
- if output_attentions:
- outputs += (self_attn_weights, )
-
- if use_cache:
- outputs += (present_key_value, )
-
- if output_router_logits:
- outputs += (router_logits, )
-
- return outputs
-
-
-def _load_pretrained_model(
- cls,
- model,
- state_dict,
- loaded_keys,
- resolved_archive_file,
- pretrained_model_name_or_path,
- ignore_mismatched_sizes=False,
- sharded_metadata=None,
- _fast_init=True,
- low_cpu_mem_usage=False,
- device_map=None,
- offload_folder=None,
- offload_state_dict=None,
- dtype=None,
- hf_quantizer=None,
- keep_in_fp32_modules=None,
- gguf_path=None,
-):
- if ((state_dict is not None) or (resolved_archive_file is None)
- or (low_cpu_mem_usage) or (device_map is not None)
- or (offload_folder is not None) or
- (not (offload_state_dict is None or offload_state_dict is False))
- or (hf_quantizer is not None) or
- (keep_in_fp32_modules is not None and len(keep_in_fp32_modules) > 0)
- or (gguf_path is not None)):
- raise NotImplementedError
-
- folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
- error_msgs = load_state_dict_into_model(model, folder)
- return model, [], [], [], None, error_msgs
-
-
-MIXTRAL_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`MixtralConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- 'The bare Mixtral Model outputting raw hidden-states without any specific head on top.',
- MIXTRAL_START_DOCSTRING,
-)
-# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
-class MixtralPreTrainedModel(PreTrainedModel):
- config_class = MixtralConfig
- base_model_prefix = 'model'
- supports_gradient_checkpointing = True
- _no_split_modules = ['MixtralDecoderLayer']
- _skip_keys_device_placement = 'past_key_values'
- _supports_flash_attn_2 = True
- _supports_sdpa = True
- _supports_cache_class = True
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
- moe_implementation = kwargs.get('moe_implementation', 'origin')
- if moe_implementation == 'origin':
- return super().from_pretrained(pretrained_model_name_or_path,
- *args, **kwargs)
-
- cls._load_pretrained_model = types.MethodType(_load_pretrained_model,
- cls)
- return super().from_pretrained(pretrained_model_name_or_path, *args,
- **kwargs)
-
-
-MIXTRAL_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- output_router_logits (`bool`, *optional*):
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
- should not be returned during inference.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-@add_start_docstrings(
- 'The bare Mixtral Model outputting raw hidden-states without any specific head on top.',
- MIXTRAL_START_DOCSTRING,
-)
-# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
-class MixtralModel(MixtralPreTrainedModel):
- """Transformer decoder consisting of *config.num_hidden_layers* layers.
- Each layer is a [`MixtralDecoderLayer`]
-
- Args:
- config: MixtralConfig
- """
-
- def __init__(self, config: MixtralConfig):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
- self.padding_idx)
- self.layers = nn.ModuleList([
- MixtralDecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
- ])
- self._attn_implementation = config._attn_implementation
- self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- # Ignore copy
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_router_logits: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MoeModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_router_logits = (
- output_router_logits if output_router_logits is not None else
- self.config.output_router_logits)
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time'
- )
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError(
- 'You have to specify either decoder_input_ids or decoder_inputs_embeds'
- )
-
- past_key_values_length = 0
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
- )
- use_cache = False
-
- if use_cache:
- use_legacy_cache = not isinstance(past_key_values, Cache)
- if use_legacy_cache:
- past_key_values = DynamicCache.from_legacy_cache(
- past_key_values)
- past_key_values_length = past_key_values.get_usable_length(
- seq_length)
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length,
- seq_length + past_key_values_length,
- dtype=torch.long,
- device=device)
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache:
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
- if is_padding_right:
- raise ValueError(
- "You are attempting to perform batched generation with padding_side='right'"
- ' this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to '
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
- )
-
- if self._attn_implementation == 'flash_attention_2':
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (
- attention_mask is not None and 0 in attention_mask) else None
- elif self._attn_implementation == 'sdpa' and not output_attentions:
- # output_attentions=True can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- sliding_window=self.config.sliding_window,
- )
- else:
- # 4d mask is passed through the layers
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- sliding_window=self.config.sliding_window,
- )
-
- hidden_states = inputs_embeds
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- all_router_logits = () if output_router_logits else None
- next_decoder_cache = None
-
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- attention_mask,
- position_ids,
- past_key_values,
- output_attentions,
- output_router_logits,
- use_cache,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- output_router_logits=output_router_logits,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[
- 2 if output_attentions else 1]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1], )
-
- if output_router_logits:
- all_router_logits += (layer_outputs[-1], )
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states, )
-
- next_cache = None
- if use_cache:
- next_cache = next_decoder_cache.to_legacy_cache(
- ) if use_legacy_cache else next_decoder_cache
-
- if not return_dict:
- return tuple(v for v in [
- hidden_states, next_cache, all_hidden_states, all_self_attns,
- all_router_logits
- ] if v is not None)
- return MoeModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- router_logits=all_router_logits,
- )
-
-
-class MixtralForCausalLM(MixtralPreTrainedModel):
- _tied_weights_keys = ['lm_head.weight']
-
- def __init__(self, config):
- super().__init__(config)
- self.model = MixtralModel(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(
- config.hidden_size, config.vocab_size, bias=False)
- self.router_aux_loss_coef = config.router_aux_loss_coef
- self.num_experts = config.num_local_experts
- self.num_experts_per_tok = config.num_experts_per_tok
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
- @replace_return_docstrings(
- output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- # Ignore copy
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_router_logits: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, MixtralForCausalLM
-
- >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_router_logits = (
- output_router_logits if output_router_logits is not None else
- self.config.output_router_logits)
-
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else
- self.config.output_hidden_states)
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- output_router_logits=output_router_logits,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- logits = self.lm_head(hidden_states)
- logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- aux_loss = None
- if output_router_logits:
- aux_loss = load_balancing_loss_func(
- outputs.router_logits if return_dict else outputs[-1],
- self.num_experts,
- self.num_experts_per_tok,
- attention_mask,
- )
- if labels is not None:
- loss += self.router_aux_loss_coef * aux_loss.to(
- loss.device) # make sure to reside in the same device
-
- if not return_dict:
- output = (logits, ) + outputs[1:]
- if output_router_logits:
- output = (aux_loss, ) + output
- return (loss, ) + output if loss is not None else output
-
- return MoeCausalLMOutputWithPast(
- loss=loss,
- aux_loss=aux_loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- router_logits=outputs.router_logits,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- inputs_embeds=None,
- output_router_logits=False,
- **kwargs,
- ):
- # Omit tokens covered by past_key_values
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- cache_length = past_key_values.get_seq_length()
- past_length = past_key_values.seen_tokens
- max_cache_length = past_key_values.get_max_length()
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
- max_cache_length = None
-
- # Keep only the unprocessed tokens:
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
- # input)
- if attention_mask is not None and attention_mask.shape[
- 1] > input_ids.shape[1]:
- input_ids = input_ids[:, -(attention_mask.shape[1] -
- past_length):]
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
- # input_ids based on the past_length.
- elif past_length < input_ids.shape[1]:
- input_ids = input_ids[:, past_length:]
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
- if (max_cache_length is not None and attention_mask is not None
- and cache_length + input_ids.shape[1] > max_cache_length):
- attention_mask = attention_mask[:, -max_cache_length:]
-
- position_ids = kwargs.get('position_ids', None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1]:]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {'inputs_embeds': inputs_embeds}
- else:
- model_inputs = {'input_ids': input_ids}
-
- model_inputs.update({
- 'position_ids': position_ids,
- 'past_key_values': past_key_values,
- 'use_cache': kwargs.get('use_cache'),
- 'attention_mask': attention_mask,
- 'output_router_logits': output_router_logits,
- })
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (tuple(
- past_state.index_select(0, beam_idx.to(past_state.device))
- for past_state in layer_past), )
- return reordered_past
-
-
-@add_start_docstrings(
- """
- The Mixtral Model transformer with a sequence classification head on top (linear layer).
-
- [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
- (e.g. GPT-2) do.
-
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- MIXTRAL_START_DOCSTRING,
-)
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
-class MixtralForSequenceClassification(MixtralPreTrainedModel):
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = MixtralModel(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache,
- List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError(
- 'Cannot handle batch sizes > 1 if no padding token is defined.'
- )
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
- sequence_lengths = torch.eq(
- input_ids, self.config.pad_token_id).int().argmax(-1) - 1
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
- sequence_lengths = sequence_lengths.to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device),
- sequence_lengths]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = 'regression'
- elif self.num_labels > 1 and (labels.dtype == torch.long
- or labels.dtype == torch.int):
- self.config.problem_type = 'single_label_classification'
- else:
- self.config.problem_type = 'multi_label_classification'
-
- if self.config.problem_type == 'regression':
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == 'single_label_classification':
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(
- pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == 'multi_label_classification':
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits, ) + transformer_outputs[1:]
- return ((loss, ) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
diff --git a/code/xtuner/model/utils.py b/code/xtuner/model/utils.py
deleted file mode 100644
index f12c515123c4e3791cf109b1ed25e5f069153183..0000000000000000000000000000000000000000
--- a/code/xtuner/model/utils.py
+++ /dev/null
@@ -1,330 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-from typing import List, Optional
-
-import torch
-from mmengine.utils.misc import get_object_from_string
-from peft import PeftType
-from torch import nn
-from transformers import PreTrainedModel
-from mmengine import print_log
-
-from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX
-
-
-def set_obj_dtype(d):
- for key, value in d.items():
- if value in ['torch.float16', 'torch.float32', 'torch.bfloat16']:
- d[key] = getattr(torch, value.split('.')[-1])
-
-
-def try_build_module(cfg):
- builder = cfg['type']
- if isinstance(builder, str):
- builder = get_object_from_string(builder)
- if builder is None:
- # support handling cfg with key 'type' can not be built, such as
- # {'rope_scaling': {'type': 'linear', 'factor': 2.0}}
- return cfg
- cfg.pop('type')
- module_built = builder(**cfg)
- return module_built
-
-
-def traverse_dict(d):
- if isinstance(d, dict):
- set_obj_dtype(d)
- for key, value in d.items():
- if isinstance(value, dict):
- traverse_dict(value)
- if 'type' in value:
- module_built = try_build_module(value)
- d[key] = module_built
- elif isinstance(d, list):
- for element in d:
- traverse_dict(element)
-
-
-def find_all_linear_names(model):
- lora_module_names = set()
- for name, module in model.named_modules():
- if isinstance(module, nn.Linear):
- names = name.split('.')
- lora_module_names.add(names[0] if len(names) == 1 else names[-1])
-
- if 'lm_head' in lora_module_names: # needed for 16-bit
- lora_module_names.remove('lm_head')
- if 'output_layer' in lora_module_names: # needed for 16-bit
- lora_module_names.remove('output_layer')
- return list(lora_module_names)
-
-def find_all_linear_names_for_dynamic_qwen2(_model_unused=None):
- return [
- "q_proj", "k_proj", "v_proj", "o_proj",
- "gate_proj", "up_proj", "down_proj",
- ]
-
-class LoadWoInit:
- """Context manager that disable parameter initialization."""
-
- def __init__(self):
- self.constant_ = torch.nn.init.constant_
- self.zeros_ = torch.nn.init.zeros_
- self.ones_ = torch.nn.init.ones_
- self.uniform_ = torch.nn.init.uniform_
- self.normal_ = torch.nn.init.normal_
- self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_
- self.kaiming_normal_ = torch.nn.init.kaiming_normal_
-
- def __enter__(self, *args, **kwargs):
- torch.nn.init.constant_ = lambda *args, **kwargs: None
- torch.nn.init.zeros_ = lambda *args, **kwargs: None
- torch.nn.init.ones_ = lambda *args, **kwargs: None
- torch.nn.init.uniform_ = lambda *args, **kwargs: None
- torch.nn.init.normal_ = lambda *args, **kwargs: None
- torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None
- torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
-
- def __exit__(self, *args, **kwargs):
- torch.nn.init.constant_ = self.constant_
- torch.nn.init.zeros_ = self.zeros_
- torch.nn.init.ones_ = self.ones_
- torch.nn.init.uniform_ = self.uniform_
- torch.nn.init.normal_ = self.normal_
- torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_
- torch.nn.init.kaiming_normal_ = self.kaiming_normal_
-
-
-def get_peft_model_state_dict(model, state_dict=None, adapter_name='default'):
- # Modified from `https://github.com/huggingface/peft/blob/main/src/peft/utils/save_and_load.py` # noqa: E501
-
- config = model.peft_config[adapter_name]
- if state_dict is None:
- state_dict = model.state_dict()
- if config.peft_type == PeftType.LORA:
- # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` # noqa: E501
- # to be used directly with the state dict which is necessary
- # when using DeepSpeed or FSDP
- bias = config.bias
- if bias == 'none':
- to_return = {k: state_dict[k] for k in state_dict if 'lora_' in k}
- elif bias == 'all':
- to_return = {
- k: state_dict[k]
- for k in state_dict if 'lora_' in k or 'bias' in k
- }
- elif bias == 'lora_only':
- to_return = {}
- for k in state_dict:
- if 'lora_' in k:
- to_return[k] = state_dict[k]
- bias_name = k.split('lora_')[0] + 'bias'
- if bias_name in state_dict:
- to_return[bias_name] = state_dict[bias_name]
- else:
- raise NotImplementedError
- to_return = {
- k: v
- for k, v in to_return.items()
- if (('lora_' in k and adapter_name in k) or ('bias' in k))
- }
- else:
- # Currently we only support lora
- raise NotImplementedError
- if model.modules_to_save is not None:
- for key, value in state_dict.items():
- if any(f'{module_name}.modules_to_save.{adapter_name}' in key
- for module_name in model.modules_to_save):
- to_return[key] = value
-
- return to_return
-
-
-# Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99 # noqa: E501
-def prepare_inputs_labels_for_multimodal(
- llm: PreTrainedModel,
- input_ids: torch.LongTensor = None,
- position_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- labels: Optional[torch.LongTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- text_features: Optional[torch.FloatTensor] = None,
- ):
- if pixel_values is None:
- return {
- 'input_ids': input_ids,
- 'position_ids': position_ids,
- 'attention_mask': attention_mask,
- 'past_key_values': past_key_values,
- 'inputs_embeds': None,
- 'labels': labels
- }
-
- _labels = labels
- _position_ids = position_ids
- _attention_mask = attention_mask
- if attention_mask is None:
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
- else:
- attention_mask = attention_mask.bool()
- if position_ids is None:
- position_ids = torch.arange(
- 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
- if labels is None:
- labels = torch.full_like(input_ids, IGNORE_INDEX)
-
- # remove the padding using attention_mask -- TODO: double check
- input_ids = [
- cur_input_ids[cur_attention_mask]
- for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
- ]
- labels = [
- cur_labels[cur_attention_mask]
- for cur_labels, cur_attention_mask in zip(labels, attention_mask)
- ]
-
- new_inputs_embeds = []
- new_labels = []
- cur_image_idx = 0
- for batch_idx, cur_input_ids in enumerate(input_ids):
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
- # print_log(f"the number of images {num_images}", 'current')
- if num_images == 0:
- cur_pixel_values = pixel_values[cur_image_idx]
- cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids) if text_features is None else text_features[batch_idx]
- cur_inputs_embeds = torch.cat(
- [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
- new_inputs_embeds.append(cur_inputs_embeds)
- new_labels.append(labels[batch_idx])
- cur_image_idx += 1
- continue
-
- image_token_indices = [-1] + torch.where(
- cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
- cur_input_ids.shape[0]
- ]
- cur_input_ids_noim = []
- cur_labels = labels[batch_idx]
- cur_labels_noim = []
- for i in range(len(image_token_indices) - 1):
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] +
- 1:image_token_indices[i +
- 1]])
- cur_labels_noim.append(cur_labels[image_token_indices[i] +
- 1:image_token_indices[i + 1]])
- split_sizes = [x.shape[0] for x in cur_labels_noim]
- cur_inputs_embeds = llm.get_input_embeddings()(
- torch.cat(cur_input_ids_noim))
- cur_inputs_embeds_no_im = torch.split(
- cur_inputs_embeds, split_sizes, dim=0)
- cur_new_inputs_embeds = []
- cur_new_labels = []
-
- for i in range(num_images + 1):
- cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
- cur_new_labels.append(cur_labels_noim[i])
- if i < num_images:
- cur_pixel_values = pixel_values[cur_image_idx]
- cur_image_idx += 1
- cur_new_inputs_embeds.append(cur_pixel_values)
- cur_new_labels.append(
- torch.full((cur_pixel_values.shape[0], ),
- IGNORE_INDEX,
- device=cur_labels.device,
- dtype=cur_labels.dtype))
-
- cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
- cur_new_labels = torch.cat(cur_new_labels)
-
- new_inputs_embeds.append(cur_new_inputs_embeds)
- new_labels.append(cur_new_labels)
-
- # Combine them
- max_len = max(x.shape[0] for x in new_inputs_embeds)
- batch_size = len(new_inputs_embeds)
-
- new_inputs_embeds_padded = []
- new_labels_padded = torch.full((batch_size, max_len),
- IGNORE_INDEX,
- dtype=new_labels[0].dtype,
- device=new_labels[0].device)
- attention_mask = torch.zeros((batch_size, max_len),
- dtype=attention_mask.dtype,
- device=attention_mask.device)
- position_ids = torch.zeros((batch_size, max_len),
- dtype=position_ids.dtype,
- device=position_ids.device)
-
- for i, (cur_new_embed,
- cur_new_labels) in enumerate(zip(new_inputs_embeds, new_labels)):
- cur_len = cur_new_embed.shape[0]
- new_inputs_embeds_padded.append(
- torch.cat((cur_new_embed,
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
- dtype=cur_new_embed.dtype,
- device=cur_new_embed.device)),
- dim=0))
- if cur_len > 0:
- new_labels_padded[i, :cur_len] = cur_new_labels
- attention_mask[i, :cur_len] = True
- position_ids[i, :cur_len] = torch.arange(
- 0,
- cur_len,
- dtype=position_ids.dtype,
- device=position_ids.device)
-
- new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
-
- if _labels is None:
- new_labels = None
- else:
- new_labels = new_labels_padded
-
- if _attention_mask is None:
- attention_mask = None
- else:
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
-
- if _position_ids is None:
- position_ids = None
-
- return {
- 'input_ids': None,
- 'position_ids': position_ids,
- 'attention_mask': attention_mask,
- 'past_key_values': past_key_values,
- 'inputs_embeds': new_inputs_embeds,
- 'labels': new_labels
- }
-
-
-
-
-
-
-def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
-
-def guess_load_checkpoint(pth_model):
- if osp.isfile(pth_model):
- state_dict = torch.load(pth_model, map_location='cpu')
- if 'state_dict' in state_dict:
- state_dict = state_dict['state_dict']
- elif osp.isdir(pth_model):
- try:
- from xtuner.utils.zero_to_any_dtype import \
- get_state_dict_from_zero_checkpoint
- except ImportError:
- raise ImportError(
- 'The provided PTH model appears to be a DeepSpeed checkpoint. '
- 'However, DeepSpeed library is not detected in current '
- 'environment. This suggests that DeepSpeed may not be '
- 'installed or is incorrectly configured. Please verify your '
- 'setup.')
- state_dict = get_state_dict_from_zero_checkpoint(
- osp.dirname(pth_model), osp.basename(pth_model))
- else:
- raise FileNotFoundError(f'Cannot find {pth_model}')
- return state_dict
diff --git a/code/xtuner/parallel/__init__.py b/code/xtuner/parallel/__init__.py
deleted file mode 100644
index 8c726230c8b8e703359ea62ff1edab1fea420052..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .sequence import * # noqa: F401, F403
diff --git a/code/xtuner/parallel/__pycache__/__init__.cpython-311.pyc b/code/xtuner/parallel/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index c19791ddeaf68be961a91fc37ce477ccf0800f52..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__init__.py b/code/xtuner/parallel/sequence/__init__.py
deleted file mode 100644
index 6e2992f78aa84f860b4465860d891b67900276f7..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.dist import init_dist
-
-from .attention import (post_process_for_sequence_parallel_attn,
- pre_process_for_sequence_parallel_attn,
- sequence_parallel_wrapper)
-from .comm import (all_to_all, gather_for_sequence_parallel,
- gather_forward_split_backward, split_for_sequence_parallel,
- split_forward_gather_backward)
-from .data_collate import (pad_cumulative_len_for_sequence_parallel,
- pad_for_sequence_parallel)
-from .reduce_loss import reduce_sequence_parallel_loss
-from .sampler import SequenceParallelSampler
-from .setup_distributed import (get_data_parallel_group,
- get_data_parallel_rank,
- get_data_parallel_world_size,
- get_inner_sequence_parallel_group,
- get_inner_sequence_parallel_rank,
- get_inner_sequence_parallel_world_size,
- get_sequence_parallel_group,
- get_sequence_parallel_rank,
- get_sequence_parallel_world_size,
- init_inner_sequence_parallel,
- init_sequence_parallel,
- is_inner_sequence_parallel_initialized)
-
-__all__ = [
- 'sequence_parallel_wrapper', 'pre_process_for_sequence_parallel_attn',
- 'post_process_for_sequence_parallel_attn', 'pad_for_sequence_parallel',
- 'split_for_sequence_parallel', 'SequenceParallelSampler',
- 'init_sequence_parallel', 'get_sequence_parallel_group',
- 'get_sequence_parallel_world_size', 'get_sequence_parallel_rank',
- 'get_data_parallel_group', 'get_data_parallel_world_size',
- 'get_data_parallel_rank', 'reduce_sequence_parallel_loss', 'init_dist',
- 'all_to_all', 'gather_for_sequence_parallel',
- 'split_forward_gather_backward', 'gather_forward_split_backward',
- 'get_inner_sequence_parallel_group', 'get_inner_sequence_parallel_rank',
- 'get_inner_sequence_parallel_world_size', 'init_inner_sequence_parallel',
- 'is_inner_sequence_parallel_initialized',
- 'pad_cumulative_len_for_sequence_parallel'
-]
diff --git a/code/xtuner/parallel/sequence/__pycache__/__init__.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 3bfda9a2e33ececffeb08461328dcd281d00a570..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__pycache__/attention.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/attention.cpython-311.pyc
deleted file mode 100644
index 1533869f66dea79bb818bbe7803244a5c8cc0581..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/attention.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__pycache__/comm.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/comm.cpython-311.pyc
deleted file mode 100644
index 8c96c2b133815261b911825967195e665fa6a499..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/comm.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__pycache__/data_collate.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/data_collate.cpython-311.pyc
deleted file mode 100644
index 79843f3cddba769450074760863bb793e3d39723..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/data_collate.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__pycache__/reduce_loss.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/reduce_loss.cpython-311.pyc
deleted file mode 100644
index 23167035cb2e399e42eeb297629e51243c50d7ad..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/reduce_loss.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__pycache__/sampler.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/sampler.cpython-311.pyc
deleted file mode 100644
index 5506513fcc1a6f323a8885f7c34506f2833070a3..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/sampler.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/__pycache__/setup_distributed.cpython-311.pyc b/code/xtuner/parallel/sequence/__pycache__/setup_distributed.cpython-311.pyc
deleted file mode 100644
index 4fef4cdb4579b1c519166b01c5eeba7483c55c84..0000000000000000000000000000000000000000
Binary files a/code/xtuner/parallel/sequence/__pycache__/setup_distributed.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/parallel/sequence/attention.py b/code/xtuner/parallel/sequence/attention.py
deleted file mode 100644
index e8bb1adaca8bd42123976c46431cfba10c21fe96..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/attention.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-
-import torch.distributed as dist
-
-from .comm import (all_to_all, gather_forward_split_backward,
- split_forward_gather_backward)
-from .setup_distributed import (get_inner_sequence_parallel_group,
- get_inner_sequence_parallel_world_size,
- get_sequence_parallel_group,
- get_sequence_parallel_world_size,
- init_inner_sequence_parallel,
- is_inner_sequence_parallel_initialized)
-
-
-def pre_process_for_sequence_parallel_attn(query_states,
- key_states,
- value_states,
- scatter_dim=2,
- gather_dim=1):
- b, s_div_sp, h, d = query_states.shape
- sp = get_sequence_parallel_world_size()
-
- if not is_inner_sequence_parallel_initialized():
- insp = sp // math.gcd(h, sp)
- init_inner_sequence_parallel(insp)
- else:
- insp = get_inner_sequence_parallel_world_size()
-
- def pre_process_for_inner_sp(q, k, v):
- if scatter_dim != 2 and gather_dim != 1:
- raise NotImplementedError(
- 'Currently only `scatter_dim == 2` and `gather_dim == 1` '
- f'is supported. But got scatter_dim = {scatter_dim} and '
- f'gather_dim = {gather_dim}.')
-
- # (b, s_div_sp, h, d) ->
- # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
- # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
- # (b, s_div_sp, insp*h, d/insp)
- q = q.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
- d // insp).transpose(3, 4).flatten(2, 4)
- k = k.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
- d // insp).transpose(3, 4).flatten(2, 4)
- v = v.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
- d // insp).transpose(3, 4).flatten(2, 4)
-
- return q, k, v
-
- def post_process_for_inner_sp(q, k, v):
- # (b, s, insp*h/sp, d/insp) -> (b, s, insp*h/sp, d)
- q = gather_forward_split_backward(q, -1,
- get_inner_sequence_parallel_group())
- k = gather_forward_split_backward(k, -1,
- get_inner_sequence_parallel_group())
- v = gather_forward_split_backward(v, -1,
- get_inner_sequence_parallel_group())
-
- return q, k, v
-
- assert (h * insp) % sp == 0, \
- ('The number of attention heads should be divisible by '
- '(sequence_parallel_world_size // sequence_parallel_inner_world_size)'
- f'. But got n_head = {h}, sequence_parallel_world_size = '
- f'{sp} and sequence_parallel_inner_world_size = {insp}.')
-
- if insp > 1:
- query_states, key_states, value_states = pre_process_for_inner_sp(
- query_states, key_states, value_states)
-
- # (b, s_div_sp, insp*h, d/insp) -> (b, s, insp*h/sp, d/insp)
- sequence_parallel_group = get_sequence_parallel_group()
- query_states = all_to_all(
- query_states,
- sequence_parallel_group,
- scatter_dim=scatter_dim,
- gather_dim=gather_dim)
- key_states = all_to_all(
- key_states,
- sequence_parallel_group,
- scatter_dim=scatter_dim,
- gather_dim=gather_dim)
- value_states = all_to_all(
- value_states,
- sequence_parallel_group,
- scatter_dim=scatter_dim,
- gather_dim=gather_dim)
-
- if insp > 1:
- query_states, key_states, value_states = post_process_for_inner_sp(
- query_states, key_states, value_states)
-
- return query_states, key_states, value_states
-
-
-def post_process_for_sequence_parallel_attn(attn_output,
- scatter_dim=1,
- gather_dim=2):
- sp = get_sequence_parallel_world_size()
- insp = get_inner_sequence_parallel_world_size()
- b, s, h_mul_insp_div_sp, d = attn_output.shape
- h = h_mul_insp_div_sp * sp // insp
- s_div_sp = s // sp
-
- if insp > 1:
- # (b, s, insp*h/sp, d) -> (b, s, insp*h/sp, d/insp)
- attn_output = split_forward_gather_backward(
- attn_output, -1, get_inner_sequence_parallel_group())
-
- # (b, s, insp*h/sp, d/insp) -> (b, s_div_sp, insp*h, d/insp)
- sequence_parallel_group = get_sequence_parallel_group()
- output = all_to_all(
- attn_output,
- sequence_parallel_group,
- scatter_dim=scatter_dim,
- gather_dim=gather_dim)
-
- if insp > 1:
- # (b, s_div_sp, insp*h, d/insp) ->
- # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
- # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
- # (b, s_div_sp, h, d)
- output = output.view(b, s_div_sp, sp // insp, insp, h * insp // sp,
- d // insp).transpose(3, 4).reshape(
- b, s_div_sp, h, d)
-
- return output
-
-
-def sequence_parallel_wrapper(local_attn):
-
- def sequence_parallel_attn(query_states, key_states, value_states, *args,
- **kwargs):
- training = kwargs.pop('training', True)
- enable_sequence_parallel = (
- dist.is_initialized() and get_sequence_parallel_world_size() > 1
- and training)
- if enable_sequence_parallel:
- query_states, key_states, value_states = \
- pre_process_for_sequence_parallel_attn(
- query_states, key_states, value_states)
-
- out = local_attn(query_states, key_states, value_states, *args,
- **kwargs)
-
- if enable_sequence_parallel:
- out = post_process_for_sequence_parallel_attn(out).contiguous()
-
- return out
-
- return sequence_parallel_attn
diff --git a/code/xtuner/parallel/sequence/comm.py b/code/xtuner/parallel/sequence/comm.py
deleted file mode 100644
index 1ff78e68c138dbf68cbda363424e460eac614b19..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/comm.py
+++ /dev/null
@@ -1,269 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Tuple
-
-import torch
-import torch.distributed as dist
-from torch import Tensor
-
-
-def _all_to_all(
- input: Tensor,
- world_size: int,
- group: dist.ProcessGroup,
- scatter_dim: int,
- gather_dim: int,
-):
- input_list = [
- t.contiguous()
- for t in torch.tensor_split(input, world_size, scatter_dim)
- ]
- output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
- dist.all_to_all(output_list, input_list, group=group)
- return torch.cat(output_list, dim=gather_dim).contiguous()
-
-
-class _AllToAll(torch.autograd.Function):
- """All-to-all communication.
-
- Args:
- input: Input tensor
- sp_group: Sequence parallel process group
- scatter_dim: Scatter dimension
- gather_dim: Gather dimension
- """
-
- @staticmethod
- def forward(ctx: Any, input: Tensor, sp_group: dist.ProcessGroup,
- scatter_dim: int, gather_dim: int):
- ctx.sp_group = sp_group
- ctx.scatter_dim = scatter_dim
- ctx.gather_dim = gather_dim
- ctx.world_size = dist.get_world_size(sp_group)
- output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim,
- gather_dim)
- return output
-
- @staticmethod
- def backward(ctx: Any, grad_output: Tensor) -> Tuple:
- grad_output = _all_to_all(
- grad_output,
- ctx.world_size,
- ctx.sp_group,
- ctx.gather_dim,
- ctx.scatter_dim,
- )
- return (
- grad_output,
- None,
- None,
- None,
- )
-
-
-def all_to_all(
- input: Tensor,
- sp_group: dist.ProcessGroup,
- scatter_dim: int = 2,
- gather_dim: int = 1,
-):
- """Convenience function to apply the all-to-all operation with scatter and
- gather dimensions.
-
- Notes:
- We have wrapped the `torch.distributed.all_to_all` function to
- enable automatic differentiation of the all-to-all operation.
-
- Args:
- input: The input tensor for which all-to-all communication is performed
- sp_group: The sequence parallel process group.
- scatter_dim: The dimension along which the input tensor is scattered
- (default: 2).
- gather_dim: The dimension along which the output tensor is gathered
- (default: 1).
-
- Returns:
- The output tensor after the all-to-all communication.
- """
- return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim)
-
-
-def split_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
- """Splits the input tensor along a given dimension for sequence parallel.
-
- Args:
- input: The input tensor to be split.
- dim: The dimension along which the tensor should be split.
- sp_group: The sequence parallel process group.
-
- Returns:
- The split tensor corresponding to the current rank's chunk.
- """
- world_size = dist.get_world_size(sp_group)
- if world_size == 1:
- return input
-
- rank = dist.get_rank(sp_group)
- dim_size = input.size(dim)
- assert dim_size % world_size == 0, (
- f'The dimension to split ({dim_size}) is not a multiple of '
- f'world size ({world_size}), cannot split tensor evenly')
-
- tensor_list = torch.split(input, dim_size // world_size, dim=dim)
- output = tensor_list[rank].contiguous()
-
- return output
-
-
-def gather_for_sequence_parallel(input, dim: int, sp_group: dist.ProcessGroup):
- """Gathers the input tensor along a given dimension for sequence parallel.
-
- Args:
- input: The input tensor to be gathered.
- dim: The dimension along which the tensor should be gathered.
- sp_group: The sequence parallel process group.
-
- Returns:
- The gathered tensor concatenated along the specified dimension.
- """
- input = input.contiguous()
- world_size = dist.get_world_size(sp_group)
- dist.get_rank(sp_group)
-
- if world_size == 1:
- return input
-
- tensor_list = [torch.empty_like(input) for _ in range(world_size)]
- assert input.device.type == 'cuda'
- dist.all_gather(tensor_list, input, group=sp_group)
-
- output = torch.cat(tensor_list, dim=dim).contiguous()
-
- return output
-
-
-class _GatherForwardSplitBackward(torch.autograd.Function):
- """Gather the input during forward.
-
- Scale and split the grad and keep only the corresponding chuck to the rank
- during backward.
- """
-
- @staticmethod
- def forward(ctx, input, dim, sp_group, grad_scale):
- ctx.dim = dim
- ctx.sp_group = sp_group
- ctx.grad_scale = grad_scale
- return gather_for_sequence_parallel(input, dim, sp_group)
-
- @staticmethod
- def backward(ctx, grad_output):
- if ctx.grad_scale == 'up':
- grad_output = grad_output * dist.get_world_size(ctx.sp_group)
- elif ctx.grad_scale == 'down':
- grad_output = grad_output / dist.get_world_size(ctx.sp_group)
-
- return (split_for_sequence_parallel(grad_output, ctx.dim,
- ctx.sp_group), None, None, None)
-
-
-class _SplitForwardGatherBackward(torch.autograd.Function):
- """Split the input and keep only the corresponding chuck to the rank during
- forward.
-
- Scale and gather the grad during backward.
- """
-
- @staticmethod
- def forward(ctx, input, dim, sp_group, grad_scale):
- ctx.dim = dim
- ctx.sp_group = sp_group
- ctx.grad_scale = grad_scale
- return split_for_sequence_parallel(input, dim, sp_group)
-
- @staticmethod
- def backward(ctx, grad_output):
- if ctx.grad_scale == 'up':
- grad_output = grad_output * dist.get_world_size(ctx.sp_group)
- elif ctx.grad_scale == 'down':
- grad_output = grad_output / dist.get_world_size(ctx.sp_group)
- return (gather_for_sequence_parallel(grad_output, ctx.dim,
- ctx.sp_group), None, None, None)
-
-
-def split_forward_gather_backward(input, dim, sp_group, grad_scale=None):
- """Split tensors according to the sp rank during forward propagation and
- gather the grad from the whole sp group during backward propagation.
-
- 1. When do we need this? input.requires_grad = True
-
- 2. Why we need grad scale?
-
- We have to scale down the grads as `gather_forward_split_backward` scales
- up the grads.
- """
- return _SplitForwardGatherBackward.apply(input, dim, sp_group, grad_scale)
-
-
-def gather_forward_split_backward(input, dim, sp_group, grad_scale=None):
- """Gather tensors from the whole sp group during forward propagation and
- split the grad according to the sp rank during backward propagation.
-
- 1. When do we need this?
-
- When sp is greater than 1, we need to slice the input `x` along
- sequence length dimension before it is passed into the model and get
- `sub_seq_x`. We then pass `sub_seq_x` into model and get output
- `sub_seq_out`. If the loss calculation process needs to use the complete
- output, we have to gather the `sub_seq_out` in all sp ranks during forward
- propagation and split the grad during backward propagation.
-
- 2. Why we need grad scale?
- Here is a simple case.
-
- -------- SP 1 -----------
- Suppose here is a toy model with only one linear module
- (in_features = 2, out_features = 1) and the input x has shape(2, 2).
- Y = [[y1], = [[w11x11 + w21x12], = [[x11, x12], dot [[w11],
- [y2]] [w11x21 + w21x22]] [x21, x22]] [w21]]
- z = mean(Y) = (y1 + y2) / 2
- Here is the partial derivative of z with respect to w11:
- ∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 + ∂z / ∂y2 * ∂y2 / ∂w11
- = 1/2 * x11 + 1/2 * x21 = (x11 + x21) / 2
-
- -------- SP 2 -----------
- When sequence parallel world size is set to 2, we will split the input x
- and scatter them to the two rank in the same sequence parallel group.
- ```Step 1
- Y_rank0 = [[y1]] = [[w11x11 + w21x12]] = [[x11, x12]] dot [[w11, w21]]^T
- Y_rank1 = [[y2]] = [[w11x21 + w21x22]] = [[x21, x22]] dot [[w11, w21]]^T
- ```
-
- Then, we have to gather them:
- ```Step 2
- Y_rank0 = [[y1],
- detach([y2])]
- Y_rank1 = [detach([y1]),
- [y2]]
- ```
- Note that y2 in Y_rank0 does not have grad, neither does y1 in Y_rank1.
-
- Similarly, we calculate the loss in each rank:
- ```Step 3
- z_rank0 = mean(Y_rank0) = (y1 + detach(y2)) / 2
- z_rank1 = mean(Y_rank1) = (detach(y1) + y2) / 2
- ```
- So the partial derivative of loss_rank0 with respect to w11:
- ```∂z / ∂w11 = ∂z / ∂y1 * ∂y1 / ∂w11 = x11 / 2```
- The same for rank1:
- ```∂z / ∂w11 = ∂z / ∂y2 * ∂y2 / ∂w11 = x21 / 2```
-
- Finally, we need to all_reduce them:
- ```Step 4
- In both rank:
- ∂z / ∂w11 = (x11 / 2 + x21 / 2) / 2 = (x11 + x21) / 4
- ```
-
- In SP2, the gradient of each param is only half of that in SP1.
- So we should scale up the grad during the backward process in Step 2.
- """ # noqa: E501
- return _GatherForwardSplitBackward.apply(input, dim, sp_group, grad_scale)
diff --git a/code/xtuner/parallel/sequence/data_collate.py b/code/xtuner/parallel/sequence/data_collate.py
deleted file mode 100644
index 048eaec103be1ab1108fcf817f5d4ed4d5ece9ab..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/data_collate.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-
-from .setup_distributed import get_sequence_parallel_world_size
-
-
-def pad_for_sequence_parallel(tensor, padding_value, dim=-1):
- length = tensor.shape[dim]
- seq_parallel_world_size = get_sequence_parallel_world_size()
- if length % seq_parallel_world_size == 0:
- return tensor
-
- pad_num = seq_parallel_world_size - (length % seq_parallel_world_size)
- pad_shape = (*tensor.shape[:dim], pad_num,
- *tensor.shape[dim + 1:]) if dim != -1 else (
- *tensor.shape[:dim], pad_num)
- pad = torch.full(
- pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
- tensor = torch.cat([tensor, pad], dim=dim)
- return tensor
-
-
-# This function only meets the following two conditions:
-# 1. use_varlen_attn = True
-# 2. pack_to_max_length = True and the lengths of each sequence are different
-def pad_cumulative_len_for_sequence_parallel(cumulative_len):
- assert len(cumulative_len) == 1
- seqlen = cumulative_len[0][-1]
- seq_parallel_world_size = get_sequence_parallel_world_size()
- if seqlen % seq_parallel_world_size == 0:
- return cumulative_len, None
-
- bs = len(cumulative_len)
- pad_len = seq_parallel_world_size - (seqlen % seq_parallel_world_size)
- seqlen_new = seqlen + pad_len
- attention_mask = torch.zeros(
- bs, seqlen_new, dtype=torch.bool, device=cumulative_len[0].device)
- attention_mask[:, :seqlen] = True
-
- for i, cu_len in enumerate(cumulative_len):
- pad = torch.tensor([seqlen_new],
- device=cu_len.device,
- dtype=cu_len.dtype)
- cumulative_len[i] = torch.cat([cu_len, pad], dim=0)
-
- return cumulative_len, attention_mask
diff --git a/code/xtuner/parallel/sequence/reduce_loss.py b/code/xtuner/parallel/sequence/reduce_loss.py
deleted file mode 100644
index fb37242a33d814826e11d985924105064d131b79..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/reduce_loss.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-import torch.distributed as dist
-
-from .setup_distributed import get_sequence_parallel_group
-
-
-class _ReduceLoss(torch.autograd.Function):
-
- @staticmethod
- def forward(ctx, mean_loss, loss_scale, process_group):
- ctx.mode = process_group
- if loss_scale == 0:
- # convert nan to 0 just for logging
- mean_loss = torch.nan_to_num(mean_loss)
- loss_sum = mean_loss * loss_scale
- dist.all_reduce(loss_sum, group=process_group)
- dist.all_reduce(loss_scale, group=process_group)
- loss = loss_sum / loss_scale
- return loss
-
- @staticmethod
- def backward(ctx, grad_output):
- return grad_output, None, None
-
-
-def reduce_sequence_parallel_loss(mean_loss,
- loss_scale,
- sp_group: dist.ProcessGroup = None):
- if dist.get_world_size(sp_group) == 1:
- return mean_loss
- if sp_group is None:
- # avoid bc breaking
- sp_group = get_sequence_parallel_group()
- return _ReduceLoss.apply(mean_loss, loss_scale, sp_group)
diff --git a/code/xtuner/parallel/sequence/sampler.py b/code/xtuner/parallel/sequence/sampler.py
deleted file mode 100644
index 69adb7cc91c5e5603b47fbb5cd438165d522a79b..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/sampler.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-from typing import Optional, Sized
-
-from mmengine.dataset import DefaultSampler
-from mmengine.dist import sync_random_seed
-
-from .setup_distributed import (get_data_parallel_rank,
- get_data_parallel_world_size)
-
-
-class SequenceParallelSampler(DefaultSampler):
-
- def __init__(self,
- dataset: Sized,
- shuffle: bool = True,
- seed: Optional[int] = None,
- round_up: bool = True) -> None:
- rank = get_data_parallel_rank()
- world_size = get_data_parallel_world_size()
- self.rank = rank
- self.world_size = world_size
-
- self.dataset = dataset
- self.shuffle = shuffle
- if seed is None:
- seed = sync_random_seed()
- self.seed = seed
- self.epoch = 0
- self.round_up = round_up
-
- if self.round_up:
- self.num_samples = math.ceil(len(self.dataset) / world_size)
- self.total_size = self.num_samples * self.world_size
- else:
- self.num_samples = math.ceil(
- (len(self.dataset) - rank) / world_size)
- self.total_size = len(self.dataset)
diff --git a/code/xtuner/parallel/sequence/setup_distributed.py b/code/xtuner/parallel/sequence/setup_distributed.py
deleted file mode 100644
index 473993a33f3f2e782e6f78594acc2bdcc120422b..0000000000000000000000000000000000000000
--- a/code/xtuner/parallel/sequence/setup_distributed.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch.distributed as dist
-
-_SEQUENCE_PARALLEL_GROUP = None
-_SEQUENCE_PARALLEL_WORLD_SIZE = None
-_SEQUENCE_PARALLEL_RANK = None
-
-_INNER_SEQUENCE_PARALLEL_GROUP = None
-_INNER_SEQUENCE_PARALLEL_WORLD_SIZE = None
-_INNER_SEQUENCE_PARALLEL_RANK = None
-
-_DATA_PARALLEL_GROUP = None
-_DATA_PARALLEL_WORLD_SIZE = None
-_DATA_PARALLEL_RANK = None
-
-
-def init_sequence_parallel(sequence_parallel_size: int = 1):
- assert dist.is_initialized()
- world_size: int = dist.get_world_size()
-
- # enable_ds_sequence_parallel = sequence_parallel_size > 1
- # if enable_ds_sequence_parallel:
- if world_size % sequence_parallel_size != 0:
- raise RuntimeError(f'world_size ({world_size}) is not divisible by '
- f'sequence_parallel_size {sequence_parallel_size}')
-
- num_sequence_parallel_groups: int = world_size // sequence_parallel_size
-
- rank = dist.get_rank()
-
- # Build the sequence parallel groups.
- global _SEQUENCE_PARALLEL_GROUP
- assert _SEQUENCE_PARALLEL_GROUP is None, \
- 'sequence parallel group is already initialized'
- for i in range(num_sequence_parallel_groups):
- ranks = range(i * sequence_parallel_size,
- (i + 1) * sequence_parallel_size)
- group = dist.new_group(ranks)
- if rank in ranks:
- _SEQUENCE_PARALLEL_GROUP = group
-
- global _DATA_PARALLEL_GROUP
- assert _DATA_PARALLEL_GROUP is None, \
- 'data parallel group is already initialized'
- all_data_parallel_group_ranks = []
- start_rank = 0
- end_rank = world_size
- for j in range(sequence_parallel_size):
- ranks = range(start_rank + j, end_rank, sequence_parallel_size)
- all_data_parallel_group_ranks.append(list(ranks))
- group = dist.new_group(ranks)
- if rank in ranks:
- _DATA_PARALLEL_GROUP = group
-
-
-def init_inner_sequence_parallel(inner_sequence_parallel_size: int = 1):
- """Build the sequence parallel inner groups.
-
- They are helpful when sp size is not evenly divided by the number of attn
- heads.
- """
- assert _SEQUENCE_PARALLEL_GROUP is not None, \
- ('Please call `init_inner_sequence_parallel` after calling '
- '`init_sequence_parallel`.')
-
- rank = dist.get_rank()
- world_size: int = dist.get_world_size()
-
- n_inner_group = world_size // inner_sequence_parallel_size
-
- global _INNER_SEQUENCE_PARALLEL_GROUP
- assert _INNER_SEQUENCE_PARALLEL_GROUP is None
-
- for i in range(n_inner_group):
- ranks = range(i * inner_sequence_parallel_size,
- (i + 1) * inner_sequence_parallel_size)
- group = dist.new_group(ranks)
- if rank in ranks:
- _INNER_SEQUENCE_PARALLEL_GROUP = group
-
-
-def is_inner_sequence_parallel_initialized():
- return _INNER_SEQUENCE_PARALLEL_GROUP is not None
-
-
-def get_inner_sequence_parallel_group():
- return _INNER_SEQUENCE_PARALLEL_GROUP
-
-
-def get_inner_sequence_parallel_world_size():
- global _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
- if _INNER_SEQUENCE_PARALLEL_WORLD_SIZE is not None:
- return _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
- if not dist.is_initialized() or (_INNER_SEQUENCE_PARALLEL_GROUP is None):
- _INNER_SEQUENCE_PARALLEL_WORLD_SIZE = 1
- else:
- _INNER_SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
- group=get_inner_sequence_parallel_group())
- return _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
-
-
-def get_inner_sequence_parallel_rank():
- global _INNER_SEQUENCE_PARALLEL_RANK
- if _INNER_SEQUENCE_PARALLEL_RANK is not None:
- return _INNER_SEQUENCE_PARALLEL_RANK
- if not dist.is_initialized() or (_INNER_SEQUENCE_PARALLEL_GROUP is None):
- _INNER_SEQUENCE_PARALLEL_RANK = 0
- else:
- _INNER_SEQUENCE_PARALLEL_RANK = dist.get_rank(
- group=get_inner_sequence_parallel_group())
- return _INNER_SEQUENCE_PARALLEL_RANK
-
-
-def get_sequence_parallel_group():
- """Get the sequence parallel group the caller rank belongs to."""
- return _SEQUENCE_PARALLEL_GROUP
-
-
-def get_sequence_parallel_world_size():
- """Return world size for the sequence parallel group."""
- global _SEQUENCE_PARALLEL_WORLD_SIZE
- if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
- return _SEQUENCE_PARALLEL_WORLD_SIZE
- if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
- _SEQUENCE_PARALLEL_WORLD_SIZE = 1
- else:
- _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
- group=get_sequence_parallel_group())
- return _SEQUENCE_PARALLEL_WORLD_SIZE
-
-
-def get_sequence_parallel_rank():
- """Return my rank for the sequence parallel group."""
- global _SEQUENCE_PARALLEL_RANK
- if _SEQUENCE_PARALLEL_RANK is not None:
- return _SEQUENCE_PARALLEL_RANK
- if not dist.is_initialized() or (_SEQUENCE_PARALLEL_GROUP is None):
- _SEQUENCE_PARALLEL_RANK = 0
- else:
- _SEQUENCE_PARALLEL_RANK = dist.get_rank(
- group=get_sequence_parallel_group())
- return _SEQUENCE_PARALLEL_RANK
-
-
-def get_data_parallel_group():
- """Get the data parallel group the caller rank belongs to."""
- assert _DATA_PARALLEL_GROUP is not None, \
- 'data parallel group is not initialized'
- return _DATA_PARALLEL_GROUP
-
-
-def get_data_parallel_world_size():
- """Return world size for the data parallel group."""
- global _DATA_PARALLEL_WORLD_SIZE
- if _DATA_PARALLEL_WORLD_SIZE is not None:
- return _DATA_PARALLEL_WORLD_SIZE
- if not dist.is_initialized():
- _DATA_PARALLEL_WORLD_SIZE = 1
- else:
- _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size(
- group=get_data_parallel_group())
- return _DATA_PARALLEL_WORLD_SIZE
-
-
-def get_data_parallel_rank():
- """Return my rank for the data parallel group."""
- global _DATA_PARALLEL_RANK
- if _DATA_PARALLEL_RANK is not None:
- return _DATA_PARALLEL_RANK
- if not dist.is_initialized():
- _DATA_PARALLEL_RANK = 0
- else:
- _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group())
- return _DATA_PARALLEL_RANK
diff --git a/code/xtuner/registry.py b/code/xtuner/registry.py
deleted file mode 100644
index 7c8907e0be44210849d029bc26c77494971220b0..0000000000000000000000000000000000000000
--- a/code/xtuner/registry.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.registry import Registry
-
-__all__ = ['BUILDER', 'MAP_FUNC']
-
-BUILDER = Registry('builder')
-MAP_FUNC = Registry('map_fn')
diff --git a/code/xtuner/scripts/acmil_testing_script.sh b/code/xtuner/scripts/acmil_testing_script.sh
deleted file mode 100644
index 86f40053beeb0fc01aa4d2d9b174225f4a78c6e9..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/acmil_testing_script.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-CONFIG="/data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/experiments_acmil/stage_2_acmil_blca.py"
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for CKPT_DIR in "${OUTPUT_DIR}"/*maxlength_acmil; do
- ckpt_name=$(basename "${CKPT_DIR}")
- last_ckpt_file="${CKPT_DIR}/last_checkpoint"
-
- if [[ ! -f "${last_ckpt_file}" ]]; then
- echo "⚠️ Skipping ${ckpt_name}: no last_checkpoint found."
- continue
- fi
-
- # 1) Read the exact .pth path and trim whitespace
- ckpt_path=$(< "${last_ckpt_file}")
- ckpt_path="${ckpt_path//[[:space:]]/}"
-
- # 2) Extract tumor code from folder name (before first '_')
- tumor_lc="${ckpt_name%%_*}"
- tumor=$(echo "${tumor_lc}" | tr '[:lower:]' '[:upper:]')
-
- # sanity check: does the JSON exist?
- tumor_json="${TUMOR_DIR}/${tumor}.json"
- if [[ ! -f "${tumor_json}" ]]; then
- echo "⚠️ Tumor JSON not found for '${tumor}' (${tumor_json}), skipping."
- continue
- fi
-
- echo
- echo "=============================================="
- echo " Testing ${tumor} with checkpoint: ${ckpt_path}"
- echo "=============================================="
-
- EVAL_LOG="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/slidechat_acmil_all_eval.txt"
-
- out_csv="${OUTPUT_DIR}/${ckpt_name}.csv"
-
- echo "--- ${ckpt_name} | tumor=${tumor} → ${out_csv} ---" | tee -a "${EVAL_LOG}"
- CUDA_VISIBLE_DEVICES=0 \
- xtuner test "${CONFIG}" \
- --checkpoint "${ckpt_path}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${EVAL_LOG}"
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/acmil_training_script.sh b/code/xtuner/scripts/acmil_training_script.sh
deleted file mode 100644
index 48e2a20acd8dc706d22685f740f0f46597cae651..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/acmil_training_script.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# CONFIGURATION
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# GPUs and DDP settings
-export CUDA_VISIBLE_DEVICES="6,7"
-export NPROC_PER_NODE=2
-
-# Paths
-DATASET_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-SCRIPT_DIR="configs/slidechat/experiments_acmil"
-DEEPSPEED_CONFIG="configs/deepspeed/deepspeed_zero2.json"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-
-# Suffix for your work‐dirs
-WORKDIR_SUFFIX="original_2048maxlength_acmil"
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# LOOP
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-for json_path in "${DATASET_DIR}"/*.json; do
- # extract e.g. "BLCA" then lowercase -> "blca"
- tumor="$(basename "${json_path}" .json)"
- tumor_lc="$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')"
-
- cfg="${SCRIPT_DIR}/stage_2_acmil_${tumor_lc}.py"
- workdir="${OUTPUT_BASE}/${tumor_lc}_${WORKDIR_SUFFIX}"
-
- if [[ ! -f "${cfg}" ]]; then
- echo "⚠️ Config not found for ${tumor_lc}, skipping: ${cfg}"
- continue
- fi
-
- echo "🚀 Starting SFT on ${tumor_lc}"
- xtuner train \
- "${cfg}" \
- --deepspeed "${DEEPSPEED_CONFIG}" \
- --work-dir "${workdir}" \
- --local_rank 0
-
- echo "✅ Finished ${tumor_lc}"
- echo
-done
-
-echo "🎉 All jobs submitted."
\ No newline at end of file
diff --git a/code/xtuner/scripts/attn_testing_script.sh b/code/xtuner/scripts/attn_testing_script.sh
deleted file mode 100644
index c717cbea09f497fc72267d4ac053602aba0e951a..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/attn_testing_script.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-CONFIG="configs/slidechat/stage_2_reducer_attn.py"
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for CKPT_DIR in "${OUTPUT_DIR}"/*attn; do
- ckpt_name=$(basename "${CKPT_DIR}")
- last_ckpt_file="${CKPT_DIR}/last_checkpoint"
-
- if [[ ! -f "${last_ckpt_file}" ]]; then
- echo "⚠️ Skipping ${ckpt_name}: no last_checkpoint found."
- continue
- fi
-
- # 1) Read the exact .pth path and trim whitespace
- ckpt_path=$(< "${last_ckpt_file}")
- ckpt_path="${ckpt_path//[[:space:]]/}"
-
- # 2) Extract tumor code from folder name (before first '_')
- tumor_lc="${ckpt_name%%_*}"
- tumor=$(echo "${tumor_lc}" | tr '[:lower:]' '[:upper:]')
-
- # sanity check: does the JSON exist?
- tumor_json="${TUMOR_DIR}/${tumor}.json"
- if [[ ! -f "${tumor_json}" ]]; then
- echo "⚠️ Tumor JSON not found for '${tumor}' (${tumor_json}), skipping."
- continue
- fi
-
- echo
- echo "=============================================="
- echo " Testing ${tumor} with checkpoint: ${ckpt_path}"
- echo "=============================================="
-
- EVAL_LOG="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/slidechat_reducer_attn_trained_all_eval2.txt"
-
- out_csv="${OUTPUT_DIR}/output_${tumor_lc}_${ckpt_name}.csv"
-
- echo "--- ${ckpt_name} | tumor=${tumor} → ${out_csv} ---" | tee -a "${EVAL_LOG}"
- CUDA_VISIBLE_DEVICES=2 \
- xtuner test "${CONFIG}" \
- --checkpoint "${ckpt_path}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${EVAL_LOG}"
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/attn_training_script.sh b/code/xtuner/scripts/attn_training_script.sh
deleted file mode 100644
index 84f5f72b9c13a513d68656ed49f906f9c859e17c..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/attn_training_script.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# CONFIGURATION
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# GPUs and DDP settings
-export CUDA_VISIBLE_DEVICES="4,5,6,7"
-export NPROC_PER_NODE=4
-
-# Paths
-DATASET_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-SCRIPT_DIR="configs/slidechat/experiments"
-DEEPSPEED_CONFIG="configs/deepspeed/deepspeed_zero2.json"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-
-# Suffix for your work‐dirs
-WORKDIR_SUFFIX="original_2048maxlength_train_token_reducer_attn"
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# LOOP
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-for json_path in "${DATASET_DIR}"/*.json; do
- # extract e.g. "BLCA" then lowercase -> "blca"
- tumor="$(basename "${json_path}" .json)"
- tumor_lc="$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')"
-
- cfg="${SCRIPT_DIR}/stage2_reducer_attn_${tumor_lc}.py"
- workdir="${OUTPUT_BASE}/${tumor_lc}_${WORKDIR_SUFFIX}"
-
- if [[ ! -f "${cfg}" ]]; then
- echo "⚠️ Config not found for ${tumor_lc}, skipping: ${cfg}"
- continue
- fi
-
- echo "🚀 Starting SFT on ${tumor_lc}"
- xtuner train \
- "${cfg}" \
- --deepspeed "${DEEPSPEED_CONFIG}" \
- --work-dir "${workdir}" \
- --local_rank 0
-
- echo "✅ Finished ${tumor_lc}"
- echo
-done
-
-echo "🎉 All jobs submitted."
\ No newline at end of file
diff --git a/code/xtuner/scripts/baseline_testing.sh b/code/xtuner/scripts/baseline_testing.sh
deleted file mode 100644
index a04a2b5d81bdf8f03a1d7c686f9a19183368726a..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/baseline_testing.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-# path to your stage_2_divprune config
-CONFIG="configs/slidechat/stage_2.py"
-
-# path to your model checkpoint (no “.pth” extension)
-CHECKPOINT="/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth"
-
-# path to the single-slide CSV
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-
-# directory containing BLCA.json, COAD.json, … SKCM.json
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-eval_txt="/data/qingq/PathVLM/baselines/github/SlideChat/outputs/slidechat_baseline_2000_visual_tokens.txt"
-
-# where to dump all your outputs
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for tumor_json in "${TUMOR_DIR}"/*.json; do
- # extract “BLCA” from “BLCA.json”
- tumor=$(basename "${tumor_json}" .json)
- # lowercase for file names: “blca”
- tumor_lc=$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')
-
- out_csv="${OUTPUT_DIR}/output_${tumor_lc}_baseline_2000_visual_tokens.csv"
-
- echo "=== Testing tumor_type=${tumor} → ${out_csv}, ${eval_txt} ==="
- CUDA_VISIBLE_DEVICES=2 \
- xtuner test "${CONFIG}" \
- --checkpoint "${CHECKPOINT}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${eval_txt}"
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/compressor_testing_script.sh b/code/xtuner/scripts/compressor_testing_script.sh
deleted file mode 100644
index 943426c70375e05b0d10097f9b4f7e422202d16a..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/compressor_testing_script.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-CONFIG="/data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/experiments_token_compressor/stage2_token_compressor_blca.py"
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for CKPT_DIR in "${OUTPUT_DIR}"/*compressor; do
- ckpt_name=$(basename "${CKPT_DIR}")
- last_ckpt_file="${CKPT_DIR}/last_checkpoint"
-
- if [[ ! -f "${last_ckpt_file}" ]]; then
- echo "⚠️ Skipping ${ckpt_name}: no last_checkpoint found."
- continue
- fi
-
- # 1) Read the exact .pth path and trim whitespace
- ckpt_path=$(< "${last_ckpt_file}")
- ckpt_path="${ckpt_path//[[:space:]]/}"
-
- # 2) Extract tumor code from folder name (before first '_')
- tumor_lc="${ckpt_name%%_*}"
- tumor=$(echo "${tumor_lc}" | tr '[:lower:]' '[:upper:]')
-
- # sanity check: does the JSON exist?
- tumor_json="${TUMOR_DIR}/${tumor}.json"
- if [[ ! -f "${tumor_json}" ]]; then
- echo "⚠️ Tumor JSON not found for '${tumor}' (${tumor_json}), skipping."
- continue
- fi
-
- echo
- echo "=============================================="
- echo " Testing ${tumor} with checkpoint: ${ckpt_path}"
- echo "=============================================="
-
- EVAL_LOG="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/slidechat_token_compressor_all_eval2.txt"
-
- out_csv="${OUTPUT_DIR}/${tumor_lc}_${ckpt_name}.csv"
-
- echo "--- ${ckpt_name} | tumor=${tumor} → ${out_csv} ---" | tee -a "${EVAL_LOG}"
- CUDA_VISIBLE_DEVICES=3 \
- xtuner test_token_compressor "${CONFIG}" \
- --checkpoint "${ckpt_path}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${EVAL_LOG}"
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/compressor_training_script.sh b/code/xtuner/scripts/compressor_training_script.sh
deleted file mode 100644
index 116be1d90306ad043e95b073635e47b1b34242b8..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/compressor_training_script.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# CONFIGURATION
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# GPUs and DDP settings
-export CUDA_VISIBLE_DEVICES="4,5,6,7"
-export NPROC_PER_NODE=4
-
-# Paths
-DATASET_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-SCRIPT_DIR="configs/slidechat/experiments_token_compressor"
-DEEPSPEED_CONFIG="configs/deepspeed/deepspeed_zero2.json"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-
-# Suffix for your work‐dirs
-WORKDIR_SUFFIX="original_2048maxlength_2fusion_token_compressor"
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# LOOP
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-for json_path in "${DATASET_DIR}"/*.json; do
- # extract e.g. "BLCA" then lowercase -> "blca"
- tumor="$(basename "${json_path}" .json)"
- tumor_lc="$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')"
-
- cfg="${SCRIPT_DIR}/stage2_token_compressor_${tumor_lc}.py"
- workdir="${OUTPUT_BASE}/${tumor_lc}_${WORKDIR_SUFFIX}"
-
- if [[ ! -f "${cfg}" ]]; then
- echo "⚠️ Config not found for ${tumor_lc}, skipping: ${cfg}"
- continue
- fi
-
- echo "🚀 Starting SFT on ${tumor_lc}"
- xtuner train \
- "${cfg}" \
- --deepspeed "${DEEPSPEED_CONFIG}" \
- --work-dir "${workdir}" \
- --local_rank 0
-
- echo "✅ Finished ${tumor_lc}"
- echo
-done
-
-echo "🎉 All jobs submitted."
\ No newline at end of file
diff --git a/code/xtuner/scripts/divprune_testing_script.sh b/code/xtuner/scripts/divprune_testing_script.sh
deleted file mode 100644
index f5492cd8ef92269e89a29ad54cefd62e42a1443f..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/divprune_testing_script.sh
+++ /dev/null
@@ -1,41 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-# path to your stage_2_divprune config
-CONFIG="configs/slidechat/stage_2_divprune.py"
-
-# path to your model checkpoint (no “.pth” extension)
-CHECKPOINT="/data/qingq/PathVLM/baselines/github/SlideChat/models/slidechat_weight/stage2_pth"
-
-# path to the single-slide CSV
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-
-# directory containing BLCA.json, COAD.json, … SKCM.json
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-eval_txt="/data/qingq/PathVLM/baselines/github/SlideChat/outputs/slidechat_divprune_cosine_050_eval2.txt"
-
-# where to dump all your outputs
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for tumor_json in "${TUMOR_DIR}"/*.json; do
- # extract “BLCA” from “BLCA.json”
- tumor=$(basename "${tumor_json}" .json)
- # lowercase for file names: “blca”
- tumor_lc=$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')
-
- out_csv="${OUTPUT_DIR}/output_${tumor_lc}_divprune_cosine_050.csv"
-
- echo "=== Testing tumor_type=${tumor} → ${out_csv}, ${eval_txt} ==="
- CUDA_VISIBLE_DEVICES=2 \
- xtuner test "${CONFIG}" \
- --checkpoint "${CHECKPOINT}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --divprune_ratio 0.5 \
- --eval_output_path "${eval_txt}"
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/dynamic_training_script.sh b/code/xtuner/scripts/dynamic_training_script.sh
deleted file mode 100644
index a59c217ca41987ac27cd11472082e93005beec82..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/dynamic_training_script.sh
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=dynamic-llava
-#SBATCH --nodes=1
-#SBATCH --ntasks-per-node=8
-#SBATCH --gres=gpu:8
-#SBATCH --cpus-per-task=8
-#SBATCH --time=72:00:00
-#SBATCH --output=logs/dynamic_llava_%j.out
-#SBATCH --error=logs/dynamic_llava_%j.err
-
-# Environment setup
-export CUDA_VISIBLE_DEVICES=3,4,5,1,0,6,7,2
-export NPROC_PER_NODE=8
-
-# Path settings
-CONFIG_DIR="configs/slidechat/experiments_dynamic_llava"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-DS_CONFIG="configs/deepspeed/deepspeed_zero3_offload.json"
-
-# Tumor types
-TUMORS=("BRCA" "LGG" "COAD" "GBM")
-
-# Create log dir if missing
-mkdir -p logs
-
-# Loop over tumor types
-for TUMOR in "${TUMORS[@]}"; do
- CONFIG_FILE="${CONFIG_DIR}/stage_2_dynamic_llava_${TUMOR,,}.py"
- WORK_DIR="${OUTPUT_BASE}/stage2_dynamic_llava_qlora_${TUMOR,,}"
-
- echo "Launching training for tumor: $TUMOR"
- echo " Config: $CONFIG_FILE"
- echo " Work dir: $WORK_DIR"
-
- xtuner train \
- "$CONFIG_FILE" \
- --deepspeed "$DS_CONFIG" \
- --work-dir "$WORK_DIR" \
- --local_rank 0
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/fusion_compressor_missing_script.sh b/code/xtuner/scripts/fusion_compressor_missing_script.sh
deleted file mode 100644
index e5a0940aa373f8d1b67bcc13c79b1cc757d20a7b..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/fusion_compressor_missing_script.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 NPROC_PER_NODE=6 xtuner train \
- configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_skcm.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_original_2048maxlength_fusion_compressor \
- --local_rank 0
-
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 NPROC_PER_NODE=6 xtuner train \
- configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_hnsc.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/hnsc_original_2048maxlength_fusion_compressor \
- --local_rank 0
\ No newline at end of file
diff --git a/code/xtuner/scripts/fusion_compressor_testing_script.sh b/code/xtuner/scripts/fusion_compressor_testing_script.sh
deleted file mode 100644
index 429c0e233205a56aa1970d158932ef5a4aaaa7ec..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/fusion_compressor_testing_script.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-CONFIG="/data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/experiments_fusion_compressor/stage_2_fusion_compressor_gbm.py"
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for CKPT_DIR in "${OUTPUT_DIR}"/*2048maxlength_fusion_compressor; do
- ckpt_name=$(basename "${CKPT_DIR}")
- last_ckpt_file="${CKPT_DIR}/last_checkpoint"
-
- if [[ ! -f "${last_ckpt_file}" ]]; then
- echo "⚠️ Skipping ${ckpt_name}: no last_checkpoint found."
- continue
- fi
-
- # 1) Read the exact .pth path and trim whitespace
- ckpt_path=$(< "${last_ckpt_file}")
- ckpt_path="${ckpt_path//[[:space:]]/}"
-
- # 2) Extract tumor code from folder name (before first '_')
- tumor_lc="${ckpt_name%%_*}"
- tumor=$(echo "${tumor_lc}" | tr '[:lower:]' '[:upper:]')
-
- # sanity check: does the JSON exist?
- tumor_json="${TUMOR_DIR}/${tumor}.json"
- if [[ ! -f "${tumor_json}" ]]; then
- echo "⚠️ Tumor JSON not found for '${tumor}' (${tumor_json}), skipping."
- continue
- fi
-
- echo
- echo "=============================================="
- echo " Testing ${tumor} with checkpoint: ${ckpt_path}"
- echo "=============================================="
-
- EVAL_LOG="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/slidechat_text_guided_fusion_compressor_all_eval2.txt"
-
- out_csv="${OUTPUT_DIR}/${tumor_lc}_${ckpt_name}.csv"
-
- echo "--- ${ckpt_name} | tumor=${tumor} → ${out_csv} ---" | tee -a "${EVAL_LOG}"
- CUDA_VISIBLE_DEVICES=0 \
- xtuner test_fusion_compressor "${CONFIG}" \
- --checkpoint "${ckpt_path}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${EVAL_LOG}"
-done
\ No newline at end of file
diff --git a/code/xtuner/scripts/fusion_compressor_token_number_testing_script.sh b/code/xtuner/scripts/fusion_compressor_token_number_testing_script.sh
deleted file mode 100644
index 532f6beb24ea47f562ad7cbe4d2d72bf0d62ad8c..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/fusion_compressor_token_number_testing_script.sh
+++ /dev/null
@@ -1,89 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-# Root directory for your per-grid-size configs
-SCRIPT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers"
-
-# CSV listing all slides for evaluation
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-
-# Directory with per-tumor JSONs
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-
-# Where all training outputs are written
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-mkdir -p "${OUTPUT_DIR}"
-
-# Sizes (i.e., token numbers / grid sizes) you trained on
-GRID_SIZES=(100)
-
-# Suffix used in your training script for work-dirs
-WORKDIR_BASE_SUFFIX="fusion_compressor_token_numbers"
-
-# Combined evaluation log
-EVAL_LOG="${OUTPUT_DIR}/slidechat_text_guided_fusion_compressor_token_numbers_all_eval.txt"
-echo "# Evaluation started $(date)" > "${EVAL_LOG}"
-# ---------------------------------------
-
-for grid_size in "${GRID_SIZES[@]}"; do
- echo
- echo "===== Testing grid_size=${grid_size} =====" | tee -a "${EVAL_LOG}"
-
- # For each tumor-specific checkpoint dir matching the pattern
- for CKPT_DIR in "${OUTPUT_DIR}"/*_"${WORKDIR_BASE_SUFFIX}"_"${grid_size}"; do
- [[ -d "${CKPT_DIR}" ]] || continue
-
- ckpt_name=$(basename "${CKPT_DIR}")
- last_ckpt_file="${CKPT_DIR}/last_checkpoint"
- if [[ ! -f "${last_ckpt_file}" ]]; then
- echo "⚠️ Skipping ${ckpt_name}: no last_checkpoint found." | tee -a "${EVAL_LOG}"
- continue
- fi
-
- # Read checkpoint path
- ckpt_path=$(< "${last_ckpt_file}")
- ckpt_path="${ckpt_path//[[:space:]]/}"
-
- # Derive tumor code
- tumor_lc="${ckpt_name%%_*}"
- tumor=$(echo "${tumor_lc}" | tr '[:lower:]' '[:upper:]')
-
- # Sanity‐check JSON
- tumor_json="${TUMOR_DIR}/${tumor}.json"
- if [[ ! -f "${tumor_json}" ]]; then
- echo "⚠️ Tumor JSON not found for '${tumor}' (${tumor_json}), skipping." | tee -a "${EVAL_LOG}"
- continue
- fi
-
- # **New**: pick the matching config for this grid size + tumor
- CONFIG="${SCRIPT_DIR}/stage_2_visual_only_fusion_compressor_${tumor_lc}_${grid_size}.py"
- if [[ ! -f "${CONFIG}" ]]; then
- echo "⚠️ Config not found: ${CONFIG}, skipping." | tee -a "${EVAL_LOG}"
- continue
- fi
-
- echo
- echo "----------------------------------------------" | tee -a "${EVAL_LOG}"
- echo " Testing ${tumor} | grid_size=${grid_size} | ckpt: ${ckpt_path}" | tee -a "${EVAL_LOG}"
- echo " Config: ${CONFIG}" | tee -a "${EVAL_LOG}"
- echo "----------------------------------------------" | tee -a "${EVAL_LOG}"
-
- out_csv="${OUTPUT_DIR}/${tumor_lc}_toknums${grid_size}_${ckpt_name}.csv"
-
- echo "--- ${ckpt_name} | grid=${grid_size} → ${out_csv} ---" | tee -a "${EVAL_LOG}"
- CUDA_VISIBLE_DEVICES=6 \
- xtuner test_fusion_compressor "${CONFIG}" \
- --checkpoint "${ckpt_path}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${EVAL_LOG}"
-
- echo "✅ Done ${tumor} @ grid ${grid_size}" | tee -a "${EVAL_LOG}"
- done
-done
-
-echo
-echo "🎉 All token-number variants have been tested!" | tee -a "${EVAL_LOG}"
\ No newline at end of file
diff --git a/code/xtuner/scripts/fusion_compressor_token_number_training_script.sh b/code/xtuner/scripts/fusion_compressor_token_number_training_script.sh
deleted file mode 100644
index 585b678d9945a4dadb89f38a42895cd0322029f3..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/fusion_compressor_token_number_training_script.sh
+++ /dev/null
@@ -1,65 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# CONFIGURATION
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# GPUs and DDP settings
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
-export NPROC_PER_NODE=6
-
-
-# Paths
-DATASET_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-SCRIPT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/experiments_fusion_compressor_token_numbers"
-DEEPSPEED_CONFIG="configs/deepspeed/deepspeed_zero2.json"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-
-# Grid sizes to test (must match what was used in config generation)
-GRID_SIZES=(100 200)
-
-# Base suffix for your work-dirs
-WORKDIR_BASE_SUFFIX="fusion_compressor_token_numbers"
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# LOOP OVER TUMOR TYPES AND GRID SIZES
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# Get list of tumor types from JSON files
-tumors=()
-for json_path in "${DATASET_DIR}"/*.json; do
- tumor="$(basename "${json_path}" .json)"
- tumors+=("${tumor}")
-done
-
-# Loop through each tumor type and grid size combination
-for tumor in "${tumors[@]}"; do
- tumor_lc="$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')"
-
- for grid_size in "${GRID_SIZES[@]}"; do
- # Construct config filename with tumor and grid size
- cfg="${SCRIPT_DIR}/stage_2_visual_only_fusion_compressor_${tumor_lc}_${grid_size}.py"
-
- # Create workdir name with tumor and grid size
- workdir="${OUTPUT_BASE}/${tumor_lc}_${WORKDIR_BASE_SUFFIX}_${grid_size}"
-
- if [[ ! -f "${cfg}" ]]; then
- echo "⚠️ Config not found for ${tumor_lc} with grid ${grid_size}, skipping: ${cfg}"
- continue
- fi
-
- echo "🚀 Starting training for ${tumor_lc} with grid size ${grid_size}"
- xtuner train \
- "${cfg}" \
- --deepspeed "${DEEPSPEED_CONFIG}" \
- --work-dir "${workdir}" \
- --local_rank 0
-
- echo "✅ Finished ${tumor_lc} with grid size ${grid_size}"
- echo
- done
-done
-
-
-echo "🎉 All jobs submitted."
\ No newline at end of file
diff --git a/code/xtuner/scripts/fusion_compressor_training_script.sh b/code/xtuner/scripts/fusion_compressor_training_script.sh
deleted file mode 100644
index 39d9e1409f8a64ecfe51c34904e3b8e1daccbb8a..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/fusion_compressor_training_script.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# CONFIGURATION
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# GPUs and DDP settings
-export CUDA_VISIBLE_DEVICES="4,5,6,1,3,2"
-export NPROC_PER_NODE=6
-
-# Paths
-DATASET_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-SCRIPT_DIR="configs/slidechat/experiments_fusion_compressor"
-DEEPSPEED_CONFIG="configs/deepspeed/deepspeed_zero2.json"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-
-# Suffix for your work‐dirs
-WORKDIR_SUFFIX="original_2048maxlength_fusion_compressor"
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# LOOP
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-for json_path in "${DATASET_DIR}"/*.json; do
- # extract e.g. "BLCA" then lowercase -> "blca"
- tumor="$(basename "${json_path}" .json)"
- tumor_lc="$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')"
-
- cfg="${SCRIPT_DIR}/stage_2_fusion_compressor_${tumor_lc}.py"
- workdir="${OUTPUT_BASE}/${tumor_lc}_${WORKDIR_SUFFIX}"
-
- if [[ ! -f "${cfg}" ]]; then
- echo "⚠️ Config not found for ${tumor_lc}, skipping: ${cfg}"
- continue
- fi
-
- echo "🚀 Starting SFT on ${tumor_lc}"
- xtuner train \
- "${cfg}" \
- --deepspeed "${DEEPSPEED_CONFIG}" \
- --work-dir "${workdir}" \
- --local_rank 0
-
- echo "✅ Finished ${tumor_lc}"
- echo
-done
-
-echo "🎉 All jobs submitted."
\ No newline at end of file
diff --git a/code/xtuner/scripts/image_only_fusion_compressor_training_script.sh b/code/xtuner/scripts/image_only_fusion_compressor_training_script.sh
deleted file mode 100644
index ff0d932fd87d6230bc972e44ab10eb4d3916a657..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/image_only_fusion_compressor_training_script.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# CONFIGURATION
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-# GPUs and DDP settings
-export CUDA_VISIBLE_DEVICES="2,3,4,5,6,1"
-export NPROC_PER_NODE=6
-
-# Paths
-DATASET_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-SCRIPT_DIR="configs/slidechat/experiments_visual_only_compressor"
-DEEPSPEED_CONFIG="configs/deepspeed/deepspeed_zero2.json"
-OUTPUT_BASE="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-
-# Suffix for your work‐dirs
-WORKDIR_SUFFIX="original_2048maxlength_visual_only_fusion_compressor"
-
-# Tumor types to exclude
-EXCLUDE=("blca" "brca")
-
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-# LOOP
-# —––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
-
-for json_path in "${DATASET_DIR}"/*.json; do
- # extract e.g. "BLCA" then lowercase -> "blca"
- tumor="$(basename "${json_path}" .json)"
- tumor_lc="$(echo "${tumor}" | tr '[:upper:]' '[:lower:]')"
-
- # skip excluded tumor types
- for bad in "${EXCLUDE[@]}"; do
- if [[ "${tumor_lc}" == "${bad}" ]]; then
- echo "⏭️ Skipping excluded tumor type: ${tumor_lc}"
- continue 2
- fi
- done
-
- cfg="${SCRIPT_DIR}/stage_2_visual_only_fusion_compressor_${tumor_lc}.py"
- workdir="${OUTPUT_BASE}/${tumor_lc}_${WORKDIR_SUFFIX}"
-
- if [[ ! -f "${cfg}" ]]; then
- echo "⚠️ Config not found for ${tumor_lc}, skipping: ${cfg}"
- continue
- fi
-
- echo "🚀 Starting SFT on ${tumor_lc}"
- xtuner train \
- "${cfg}" \
- --deepspeed "${DEEPSPEED_CONFIG}" \
- --work-dir "${workdir}" \
- --local_rank 0
-
- echo "✅ Finished ${tumor_lc}"
- echo
-done
-
-echo "🎉 All jobs submitted."
\ No newline at end of file
diff --git a/code/xtuner/scripts/missed_training_script.sh b/code/xtuner/scripts/missed_training_script.sh
deleted file mode 100644
index 6e9debbe9a71fd272c0f13cb4144e857f6a3bb6e..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/missed_training_script.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-CUDA_VISIBLE_DEVICES=4,5,6,7,3,1 NPROC_PER_NODE=6 xtuner train \
- configs/slidechat/experiments_token_compressor/stage2_token_compressor_lgg.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/lgg_orignal_2048maxlength_train_token_compressor \
- --local_rank 0
-
-CUDA_VISIBLE_DEVICES=4,5,6,7,3,1 NPROC_PER_NODE=6 xtuner train \
- configs/slidechat/experiments_attn/stage2_reducer_attn_lgg.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/lgg_original_2048maxlength_train_token_reducer_attn \
- --local_rank 0
-
-CUDA_VISIBLE_DEVICES=4,5,6,7,3,1 NPROC_PER_NODE=6 xtuner train \
- configs/slidechat/experiments_attn/stage2_reducer_attn_luad.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/luad_original_2048maxlength_train_token_reducer_attn \
- --local_rank 0
\ No newline at end of file
diff --git a/code/xtuner/scripts/tflops_scripts.sh b/code/xtuner/scripts/tflops_scripts.sh
deleted file mode 100644
index 40b6a252ea1b3ca8299354fd3594355d3b929e6b..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/tflops_scripts.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-CUDA_VISIBLE_DEVICES=6 NPROC_PER_NODE=1 xtuner train \
- configs/slidechat/stage_2.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_original_2048maxlength_longnet_testing2 \
- --local_rank 0
-
-CUDA_VISIBLE_DEVICES=6 NPROC_PER_NODE=1 xtuner train \
- /data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/stage_2_fusion_compressor_500.py \
- --deepspeed configs/deepspeed/deepspeed_zero2.json \
- --work-dir /data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/skcm_original_2048maxlength_fusion_compression_testing_500 \
- --local_rank 0
\ No newline at end of file
diff --git a/code/xtuner/scripts/visual_only_fusion_compressor_testing_script.sh b/code/xtuner/scripts/visual_only_fusion_compressor_testing_script.sh
deleted file mode 100644
index 29681a6f9c898b49b6ad89fa37b007403b9f7a25..0000000000000000000000000000000000000000
--- a/code/xtuner/scripts/visual_only_fusion_compressor_testing_script.sh
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# ------------ CONFIGURATION ------------
-CONFIG="/data/qingq/PathVLM/baselines/github/SlideChat/xtuner/configs/slidechat/experiments_visual_only_compressor/stage_2_visual_only_fusion_compressor_blca.py"
-TEST_CSV="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/SlideBench-VQA-TCGA.csv"
-TUMOR_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/dataset/stage_2_vqa_by_tumor/stage2_vqa_tumor_"
-OUTPUT_DIR="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs"
-mkdir -p "${OUTPUT_DIR}"
-# ---------------------------------------
-
-for CKPT_DIR in "${OUTPUT_DIR}"/*visual_only_fusion_compressor; do
- ckpt_name=$(basename "${CKPT_DIR}")
- last_ckpt_file="${CKPT_DIR}/last_checkpoint"
-
- if [[ ! -f "${last_ckpt_file}" ]]; then
- echo "⚠️ Skipping ${ckpt_name}: no last_checkpoint found."
- continue
- fi
-
- # 1) Read the exact .pth path and trim whitespace
- ckpt_path=$(< "${last_ckpt_file}")
- ckpt_path="${ckpt_path//[[:space:]]/}"
-
- # 2) Extract tumor code from folder name (before first '_')
- tumor_lc="${ckpt_name%%_*}"
- tumor=$(echo "${tumor_lc}" | tr '[:lower:]' '[:upper:]')
-
- # sanity check: does the JSON exist?
- tumor_json="${TUMOR_DIR}/${tumor}.json"
- if [[ ! -f "${tumor_json}" ]]; then
- echo "⚠️ Tumor JSON not found for '${tumor}' (${tumor_json}), skipping."
- continue
- fi
-
- echo
- echo "=============================================="
- echo " Testing ${tumor} with checkpoint: ${ckpt_path}"
- echo "=============================================="
-
- EVAL_LOG="/data/qingq/PathVLM/baselines/github/SlideChat/models/outputs/slidechat_visual_only_fusion_compressor_all_eval.txt"
-
- out_csv="${OUTPUT_DIR}/${tumor_lc}_${ckpt_name}.csv"
-
- echo "--- ${ckpt_name} | tumor=${tumor} → ${out_csv} ---" | tee -a "${EVAL_LOG}"
- CUDA_VISIBLE_DEVICES=0 \
- xtuner test_fusion_compressor "${CONFIG}" \
- --checkpoint "${ckpt_path}" \
- --test_slide_csv "${TEST_CSV}" \
- --test_output_csv "${out_csv}" \
- --local_rank 0 \
- --tumor_type "${tumor}" \
- --eval_output_path "${EVAL_LOG}"
-done
\ No newline at end of file
diff --git a/code/xtuner/slidechat_baseline_eval.txt b/code/xtuner/slidechat_baseline_eval.txt
deleted file mode 100644
index 8710801c5c1041c5634ea7a73348932415feb1f4..0000000000000000000000000000000000000000
--- a/code/xtuner/slidechat_baseline_eval.txt
+++ /dev/null
@@ -1,173 +0,0 @@
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_lgg_text_guided_reducer_attn.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 121
- Correct : 91
- Accuracy : 75.21%
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_lgg_text_guided_reducer_attn.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 121
- Correct : 92
- Accuracy : 76.03%
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_lgg_text_guided_reducer_attn.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 121
- Correct : 94
- Accuracy : 77.69%
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_blca_random_selection.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 158
- Correct : 130
- Accuracy : 82.28%
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_reducer_attn_rephase.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 27
- Accuracy : 79.41%
- Average Generation Time : 1.7586 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_reducer_attn_rephase.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 27
- Accuracy : 79.41%
- Average Generation Time : 0.3120 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 5.6964 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices2.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 5.5581 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_no_visual_input.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 5.8198 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage1.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 5.2619 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage1_llm_only.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 5.2839 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_llm_only.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 4
- Accuracy : 11.76%
- Average Generation Time : 5.5989 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_random_fixed_question_id.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 4.2200 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_random_fixed_question_id.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 1
- Accuracy : 2.94%
- Average Generation Time : 4.1402 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_random_fixed_question_id_2.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 20
- Correct : 0
- Accuracy : 0.00%
- Average Generation Time : 1.3137 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_random_fixed_question_id_2.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 20
- Correct : 0
- Accuracy : 0.00%
- Average Generation Time : 2.9803 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_beam_search_decoding.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 45
- Correct : 3
- Accuracy : 6.67%
- Average Generation Time : 6.5524 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_beam_search_decoding.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 1
- Correct : 0
- Accuracy : 0.00%
- Average Generation Time : 7.5340 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_beam_search_decoding.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 3
- Accuracy : 8.82%
- Average Generation Time : 6.4082 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_random_fixed_question_id_2_decoding.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 20
- Correct : 0
- Accuracy : 0.00%
- Average Generation Time : 20.2492 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_test_no_choices_stage2_random_fixed_question_id_decoding_2.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 20
- Correct : 0
- Accuracy : 0.00%
- Average Generation Time : 20.2163 seconds
- Image Seq Len (avg/min/max) : 11482.2/3095/23567
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_divprune_025.csv
-Evaluation Summary:
- ---------------------
- Total Samples : 34
- Correct : 27
- Accuracy : 79.41%
- Average Generation Time : 0.1512 seconds
-
diff --git a/code/xtuner/slidechat_eval.txt b/code/xtuner/slidechat_eval.txt
deleted file mode 100644
index 8a55597d3afebe05a4af92e0c7d74a6cdc4a55b4..0000000000000000000000000000000000000000
--- a/code/xtuner/slidechat_eval.txt
+++ /dev/null
@@ -1,24 +0,0 @@
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_dynamic_llava_divprune_010.csv
-Evaluation Summary:
----------------------
-Total Samples : 34
-Correct : 28
-Accuracy : 82.35%
-Average Generation Time : 0.3851 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_dynamic_llava_divprune_010.csv
-Evaluation Summary:
----------------------
-Total Samples : 34
-Correct : 27
-Accuracy : 79.41%
-Average Generation Time : 0.4414 seconds
-
-/data/qingq/PathVLM/baselines/github/SlideChat/outputs/output_skcm_dynamic_llava_divprune_03.csv
-Evaluation Summary:
----------------------
-Total Samples : 34
-Correct : 29
-Accuracy : 85.29%
-Average Generation Time : 0.5843 seconds
-
diff --git a/code/xtuner/tflops.py b/code/xtuner/tflops.py
deleted file mode 100644
index 60db7239a16eb1a325951cfcf89a8ce70e58529a..0000000000000000000000000000000000000000
--- a/code/xtuner/tflops.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import torch
-from llava.model.builder import load_pretrained_model
-from llava.mm_utils import get_model_name_from_path
-from llava.eval.run_llava import eval_model
-from ptflops import get_model_complexity_info
-
-def get_llava_flops(model_path):
- """
- Calculates the FLOPs and number of parameters for a LLaVA model.
- """
- model_name = get_model_name_from_path(model_path)
- tokenizer, model, image_processor, context_len = load_pretrained_model(
- model_path=model_path,
- model_base=None,
- model_name=model_name,
- load_4bit=True, # Set to True if you are using a 4-bit quantized model
- )
-
- # Prepare a dummy input for the model.
- # The input dimensions should match what the model expects.
- # For LLaVA, the input is typically a combination of image and text.
- # The image is processed to a tensor of size (1, 3, 336, 336) for LLaVA-1.5
- image_tensor = torch.randn(1, 3, 336, 336).to(model.device)
-
- # The text input is a sequence of token IDs.
- # We will use a dummy sequence of length 512.
- input_ids = torch.randint(0, tokenizer.vocab_size, (1, 512)).to(model.device)
-
-
- # Use ptflops to get the complexity information
- # We will use the 'aten' backend which is more suitable for transformer models
- macs, params = get_model_complexity_info(
- model,
- input_res=(3, 336, 336), # A tuple representing the image resolution
- input_constructor=lambda res: {'images': image_tensor, 'input_ids': input_ids},
- as_strings=True,
- print_per_layer_stat=True,
- verbose=True,
- backend='aten'
- )
-
- print(f"Model: {model_name}")
- print(f"Computational complexity: {macs}")
- print(f"Number of parameters: {params}")
-
-
-if __name__ == "__main__":
- # Add a new argument to the argument parser to trigger FLOPs calculation,
- # or simply call the function directly with the model path.
-
- # For example:
- model_path = "liuhaotian/llava-v1.5-7b"
- get_llava_flops(model_path)
-
- # You can also integrate this into the existing argument parsing logic
- # of the run_llava.py script.
\ No newline at end of file
diff --git a/code/xtuner/tools/.DS_Store b/code/xtuner/tools/.DS_Store
deleted file mode 100644
index f8e8aa5039f2f4f251aeaa3d9dfa7f4264889b6b..0000000000000000000000000000000000000000
Binary files a/code/xtuner/tools/.DS_Store and /dev/null differ
diff --git a/code/xtuner/tools/__pycache__/train.cpython-311.pyc b/code/xtuner/tools/__pycache__/train.cpython-311.pyc
deleted file mode 100644
index 28aa115ebbae4216fcd94631fa1034a499ca3de7..0000000000000000000000000000000000000000
Binary files a/code/xtuner/tools/__pycache__/train.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/tools/__pycache__/utils.cpython-311.pyc b/code/xtuner/tools/__pycache__/utils.cpython-311.pyc
deleted file mode 100644
index 7550b32745076381644d235adeea8910e5d9ff21..0000000000000000000000000000000000000000
Binary files a/code/xtuner/tools/__pycache__/utils.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/tools/chat.py b/code/xtuner/tools/chat.py
deleted file mode 100644
index 3bddac52cdcca8c2e5ef7ac5e10ebcd444897e5f..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/chat.py
+++ /dev/null
@@ -1,491 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-import re
-import sys
-
-import torch
-from huggingface_hub import snapshot_download
-from peft import PeftModel
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-from transformers.generation.streamers import TextStreamer
-
-from xtuner.dataset.utils import expand2square, load_image
-from xtuner.model.utils import prepare_inputs_labels_for_multimodal
-from xtuner.tools.utils import get_stop_criteria
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-
-def remove_prefix(state_dict, prefix):
- new_state_dict = {}
- for key, value in state_dict.items():
- if key.startswith(prefix):
- new_key = key[len(prefix):]
- new_state_dict[new_key] = value
- else:
- new_state_dict[key] = value
- return new_state_dict
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Chat with a HF model')
- parser.add_argument(
- 'model_name_or_path', help='Hugging Face model name or path')
- adapter_group = parser.add_mutually_exclusive_group()
- adapter_group.add_argument(
- '--adapter', default=None, help='adapter name or path')
- adapter_group.add_argument(
- '--llava', default=None, help='llava name or path')
- parser.add_argument(
- '--visual-encoder', default=None, help='visual encoder name or path')
- parser.add_argument(
- '--visual-select-layer', default=-2, help='visual select layer')
- parser.add_argument('--image', default=None, help='image')
- parser.add_argument(
- '--torch-dtype',
- default='fp16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--prompt-template',
- choices=PROMPT_TEMPLATE.keys(),
- default=None,
- help='Specify a prompt template')
- system_group = parser.add_mutually_exclusive_group()
- system_group.add_argument(
- '--system', default=None, help='Specify the system text')
- system_group.add_argument(
- '--system-template',
- choices=SYSTEM_TEMPLATE.keys(),
- default=None,
- help='Specify a system template')
- parser.add_argument(
- '--bits',
- type=int,
- choices=[4, 8, None],
- default=None,
- help='LLM bits')
- parser.add_argument(
- '--bot-name', type=str, default='BOT', help='Name for Bot')
- parser.add_argument(
- '--with-plugins',
- nargs='+',
- choices=['calculate', 'solve', 'search'],
- help='Specify plugins to use')
- parser.add_argument(
- '--no-streamer', action='store_true', help='Whether to with streamer')
- parser.add_argument(
- '--lagent', action='store_true', help='Whether to use lagent')
- parser.add_argument(
- '--stop-words', nargs='+', type=str, default=[], help='Stop words')
- parser.add_argument(
- '--offload-folder',
- default=None,
- help='The folder in which to offload the model weights (or where the '
- 'model weights are already offloaded).')
- parser.add_argument(
- '--max-new-tokens',
- type=int,
- default=2048,
- help='Maximum number of new tokens allowed in generated text')
- parser.add_argument(
- '--temperature',
- type=float,
- default=0.1,
- help='The value used to modulate the next token probabilities.')
- parser.add_argument(
- '--top-k',
- type=int,
- default=40,
- help='The number of highest probability vocabulary tokens to '
- 'keep for top-k-filtering.')
- parser.add_argument(
- '--top-p',
- type=float,
- default=0.75,
- help='If set to float < 1, only the smallest set of most probable '
- 'tokens with probabilities that add up to top_p or higher are '
- 'kept for generation.')
- parser.add_argument(
- '--repetition-penalty',
- type=float,
- default=1.0,
- help='The parameter for repetition penalty. 1.0 means no penalty.')
- parser.add_argument(
- '--seed',
- type=int,
- default=0,
- help='Random seed for reproducible text generation')
- args = parser.parse_args()
- return args
-
-
-def get_input():
- """Helper function for getting input from users."""
- sentinel = '' # ends when this string is seen
- result = None
- while result is None:
- print(('\ndouble enter to end input (EXIT: exit chat, '
- 'RESET: reset history) >>> '),
- end='')
- try:
- result = '\n'.join(iter(input, sentinel))
- except UnicodeDecodeError:
- print('Invalid characters detected. Please enter again.')
- return result
-
-
-def main():
- args = parse_args()
- torch.manual_seed(args.seed)
-
- # build llm
- quantization_config = None
- load_in_8bit = False
- if args.bits == 4:
- quantization_config = BitsAndBytesConfig(
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4')
- elif args.bits == 8:
- load_in_8bit = True
- model_kwargs = {
- 'quantization_config': quantization_config,
- 'load_in_8bit': load_in_8bit,
- 'device_map': 'auto',
- 'offload_folder': args.offload_folder,
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- }
- if args.lagent:
- from lagent.actions import ActionExecutor, GoogleSearch
- from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
- ReAct, ReActProtocol)
- from lagent.llms import HFTransformerCasualLM
-
- try:
- SERPER_API_KEY = os.environ['SERPER_API_KEY']
- except Exception:
- print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
- 'and set it using `export SERPER_API_KEY=xxx`.')
- sys.exit(1)
-
- model_kwargs.pop('trust_remote_code')
- llm = HFTransformerCasualLM(
- args.model_name_or_path, model_kwargs=model_kwargs)
- if args.adapter is not None:
- print(f'Loading adapter from {args.adapter}...')
- llm.model = PeftModel.from_pretrained(
- llm.model,
- args.adapter,
- offload_folder=args.offload_folder,
- trust_remote_code=True)
- search_tool = GoogleSearch(api_key=SERPER_API_KEY)
- chatbot = ReAct(
- llm=llm,
- action_executor=ActionExecutor(actions=[search_tool]),
- protocol=ReActProtocol(
- call_protocol=CALL_PROTOCOL_CN,
- force_stop=FORCE_STOP_PROMPT_CN))
- while True:
- text = get_input()
- while text.strip() == 'RESET':
- print('Log: History responses have been removed!')
- chatbot._session_history = []
- inputs = ''
- text = get_input()
- if text.strip() == 'EXIT':
- print('Log: Exit!')
- exit(0)
- response = chatbot.chat(text)
- print(response.response)
- else:
- if args.with_plugins is None:
- inner_thoughts_open = False
- calculate_open = False
- solve_open = False
- search_open = False
- else:
- assert args.prompt_template == args.system_template == 'moss_sft'
- from plugins import plugins_api
- inner_thoughts_open = True
- calculate_open = 'calculate' in args.with_plugins
- solve_open = 'solve' in args.with_plugins
- search_open = 'search' in args.with_plugins
- # pre-import for api and model preparation
- if calculate_open:
- from plugins import calculate # noqa: F401
- if solve_open:
- from plugins import solve # noqa: F401
- if search_open:
- from plugins import search # noqa: F401
- # build llm
- llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
- **model_kwargs)
- tokenizer = AutoTokenizer.from_pretrained(
- args.model_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
- print(f'Load LLM from {args.model_name_or_path}')
- if args.adapter is not None:
- llm = PeftModel.from_pretrained(
- llm,
- args.adapter,
- offload_folder=args.offload_folder,
- trust_remote_code=True)
- print(f'Load adapter from {args.adapter}')
- if args.llava is not None:
- llava_path = snapshot_download(
- repo_id=args.llava) if not osp.isdir(
- args.llava) else args.llava
-
- # build visual_encoder
- if 'visual_encoder' in os.listdir(llava_path):
- assert args.visual_encoder is None, (
- "Please don't specify the `--visual-encoder` since passed "
- '`--llava` contains a visual encoder!')
- visual_encoder_path = osp.join(llava_path, 'visual_encoder')
- else:
- assert args.visual_encoder is not None, (
- 'Please specify the `--visual-encoder`!')
- visual_encoder_path = args.visual_encoder
- visual_encoder = CLIPVisionModel.from_pretrained(
- visual_encoder_path,
- torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
- image_processor = CLIPImageProcessor.from_pretrained(
- visual_encoder_path)
- print(f'Load visual_encoder from {visual_encoder_path}')
-
- # load adapter
- if 'llm_adapter' in os.listdir(llava_path):
- adapter_path = osp.join(llava_path, 'llm_adapter')
- llm = PeftModel.from_pretrained(
- llm,
- adapter_path,
- offload_folder=args.offload_folder,
- trust_remote_code=True)
- print(f'Load LLM adapter from {args.llava}')
- if 'visual_encoder_adapter' in os.listdir(llava_path):
- adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
- visual_encoder = PeftModel.from_pretrained(
- visual_encoder,
- adapter_path,
- offload_folder=args.offload_folder)
- print(f'Load visual_encoder adapter from {args.llava}')
-
- # build projector
- projector_path = osp.join(llava_path, 'projector')
- projector = AutoModel.from_pretrained(
- projector_path,
- torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
- trust_remote_code=True)
- print(f'Load projector from {args.llava}')
-
- projector.cuda()
- projector.eval()
- visual_encoder.cuda()
- visual_encoder.eval()
-
- llm.eval()
-
- if args.image is not None:
- image = load_image(args.image)
- image = expand2square(
- image, tuple(int(x * 255) for x in image_processor.image_mean))
- image = image_processor.preprocess(
- image, return_tensors='pt')['pixel_values'][0]
- image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
- visual_outputs = visual_encoder(image, output_hidden_states=True)
- pixel_values = projector(
- visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
-
- stop_words = args.stop_words
- sep = ''
- if args.prompt_template:
- template = PROMPT_TEMPLATE[args.prompt_template]
- stop_words += template.get('STOP_WORDS', [])
- sep = template.get('SEP', '')
- stop_criteria = get_stop_criteria(
- tokenizer=tokenizer, stop_words=stop_words)
-
- if args.no_streamer:
- streamer = None
- else:
- streamer = TextStreamer(tokenizer, skip_prompt=True)
-
- gen_config = GenerationConfig(
- max_new_tokens=args.max_new_tokens,
- do_sample=args.temperature > 0,
- temperature=args.temperature,
- top_p=args.top_p,
- top_k=args.top_k,
- repetition_penalty=args.repetition_penalty,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
-
- n_turn = 0
- inputs = ''
- while True:
- text = get_input()
- while text.strip() == 'RESET':
- print('Log: History responses have been removed!')
- n_turn = 0
- inputs = ''
- text = get_input()
- if text.strip() == 'EXIT':
- print('Log: Exit!')
- exit(0)
-
- if args.image is not None and n_turn == 0:
- text = DEFAULT_IMAGE_TOKEN + '\n' + text
-
- if args.prompt_template:
- prompt_text = ''
- template = PROMPT_TEMPLATE[args.prompt_template]
- if 'SYSTEM' in template and n_turn == 0:
- system_text = None
- if args.system_template is not None:
- system_text = SYSTEM_TEMPLATE[
- args.system_template].format(
- round=n_turn + 1, bot_name=args.bot_name)
- elif args.system is not None:
- system_text = args.system
- if system_text is not None:
- prompt_text += template['SYSTEM'].format(
- system=system_text,
- round=n_turn + 1,
- bot_name=args.bot_name)
- prompt_text += template['INSTRUCTION'].format(
- input=text, round=n_turn + 1, bot_name=args.bot_name)
- if args.prompt_template == args.system_template == 'moss_sft':
- if not inner_thoughts_open:
- prompt_text.replace('- Inner thoughts: enabled.',
- '- Inner thoughts: disabled.')
- if not calculate_open:
- prompt_text.replace(('- Calculator: enabled. API: '
- 'Calculate(expression)'),
- '- Calculator: disabled.')
- if not solve_open:
- prompt_text.replace(
- '- Equation solver: enabled. API: Solve(equation)',
- '- Equation solver: disabled.')
- if not search_open:
- prompt_text.replace(
- '- Web search: enabled. API: Search(query)',
- '- Web search: disabled.')
- else:
- prompt_text = text
- inputs += prompt_text
- if args.image is None:
- if n_turn == 0:
- ids = tokenizer.encode(inputs, return_tensors='pt')
- else:
- ids = tokenizer.encode(
- inputs, return_tensors='pt', add_special_tokens=False)
-
- if args.with_plugins is not None:
- generate_output = llm.generate(
- inputs=ids.cuda(),
- generation_config=gen_config,
- streamer=streamer,
- stopping_criteria=stop_criteria).cpu()
- generate_output_text = tokenizer.decode(
- generate_output[0][len(ids[0]):])
- if streamer is None:
- end = '' if generate_output_text[-1] == '\n' else '\n'
- print(generate_output_text, end=end)
- pattern = r'<\|Commands\|>:(.*?)'
- command_text = ', '.join(
- re.findall(pattern, generate_output_text))
- extent_text = plugins_api(
- command_text,
- calculate_open=calculate_open,
- solve_open=solve_open,
- search_open=search_open)
- end = '' if extent_text[-1] == '\n' else '\n'
- print(extent_text, end=end)
- extent_text_ids = tokenizer.encode(
- extent_text,
- return_tensors='pt',
- add_special_tokens=False)
- new_ids = torch.cat((generate_output, extent_text_ids),
- dim=1)
-
- generate_output = llm.generate(
- inputs=new_ids.cuda(),
- generation_config=gen_config,
- streamer=streamer,
- stopping_criteria=stop_criteria)
- if streamer is None:
- output_text = tokenizer.decode(
- generate_output[0][len(new_ids[0]):])
- end = '' if output_text[-1] == '\n' else '\n'
- print(output_text, end=end)
- else:
- generate_output = llm.generate(
- inputs=ids.cuda(),
- generation_config=gen_config,
- streamer=streamer,
- stopping_criteria=stop_criteria)
- if streamer is None:
- output_text = tokenizer.decode(
- generate_output[0][len(ids[0]):])
- end = '' if output_text[-1] == '\n' else '\n'
- print(output_text, end=end)
- inputs = tokenizer.decode(generate_output[0])
- else:
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0 and n_turn == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(
- chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- ids.append(IMAGE_TOKEN_INDEX)
- ids = torch.tensor(ids).cuda().unsqueeze(0)
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=llm, input_ids=ids, pixel_values=pixel_values)
-
- generate_output = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=streamer,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
- if streamer is None:
- output_text = tokenizer.decode(generate_output[0])
- end = '' if output_text[-1] == '\n' else '\n'
- print(output_text, end=end)
- inputs += tokenizer.decode(generate_output[0])
- n_turn += 1
- inputs += sep
- if len(generate_output[0]) >= args.max_new_tokens:
- print(
- 'Remove the memory of history responses, since '
- f'it exceeds the length limitation {args.max_new_tokens}.')
- n_turn = 0
- inputs = ''
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/check_custom_dataset.py b/code/xtuner/tools/check_custom_dataset.py
deleted file mode 100644
index d9d005fb5b6e9f7b3b0cf964d5dd45c4acdd5a4a..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/check_custom_dataset.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-from functools import partial
-
-import numpy as np
-from datasets import DatasetDict
-from mmengine.config import Config
-
-from xtuner.dataset.utils import Packer, encode_fn
-from xtuner.registry import BUILDER
-
-
-def parse_args():
- parser = argparse.ArgumentParser(
- description='Verify the correctness of the config file for the '
- 'custom dataset.')
- parser.add_argument('config', help='config file name or path.')
- args = parser.parse_args()
- return args
-
-
-def is_standard_format(dataset):
- example = next(iter(dataset))
- if 'conversation' not in example:
- return False
- conversation = example['conversation']
- if not isinstance(conversation, list):
- return False
- for item in conversation:
- if (not isinstance(item, dict)) or ('input'
- not in item) or ('output'
- not in item):
- return False
- input, output = item['input'], item['output']
- if (not isinstance(input, str)) or (not isinstance(output, str)):
- return False
- return True
-
-
-def main():
- args = parse_args()
-
- cfg = Config.fromfile(args.config)
-
- tokenizer = BUILDER.build(cfg.tokenizer)
- if cfg.get('framework', 'mmengine').lower() == 'huggingface':
- train_dataset = cfg.train_dataset
- else:
- train_dataset = cfg.train_dataloader.dataset
-
- dataset = train_dataset.dataset
- max_length = train_dataset.max_length
- dataset_map_fn = train_dataset.get('dataset_map_fn', None)
- template_map_fn = train_dataset.get('template_map_fn', None)
- max_dataset_length = train_dataset.get('max_dataset_length', 10)
- split = train_dataset.get('split', 'train')
- remove_unused_columns = train_dataset.get('remove_unused_columns', False)
- rename_maps = train_dataset.get('rename_maps', [])
- shuffle_before_pack = train_dataset.get('shuffle_before_pack', True)
- pack_to_max_length = train_dataset.get('pack_to_max_length', True)
- input_ids_with_output = train_dataset.get('input_ids_with_output', True)
-
- if dataset.get('path', '') != 'json':
- raise ValueError(
- 'You are using custom datasets for SFT. '
- 'The custom datasets should be in json format. To load your JSON '
- 'file, you can use the following code snippet: \n'
- '"""\nfrom datasets import load_dataset \n'
- 'dataset = dict(type=load_dataset, path=\'json\', '
- 'data_files=\'your_json_file.json\')\n"""\n'
- 'For more details, please refer to Step 5 in the '
- '`Using Custom Datasets` section of the documentation found at'
- ' docs/zh_cn/user_guides/single_turn_conversation.md.')
-
- try:
- dataset = BUILDER.build(dataset)
- except RuntimeError:
- raise RuntimeError(
- 'Unable to load the custom JSON file using '
- '`datasets.load_dataset`. Your data-related config is '
- f'{train_dataset}. Please refer to the official documentation on'
- ' `load_dataset` (https://huggingface.co/docs/datasets/loading) '
- 'for more details.')
-
- if isinstance(dataset, DatasetDict):
- dataset = dataset[split]
-
- if not is_standard_format(dataset) and dataset_map_fn is None:
- raise ValueError(
- 'If the custom dataset is not in the XTuner-defined '
- 'format, please utilize `dataset_map_fn` to map the original data'
- ' to the standard format. For more details, please refer to '
- 'Step 1 and Step 5 in the `Using Custom Datasets` section of the '
- 'documentation found at '
- '`docs/zh_cn/user_guides/single_turn_conversation.md`.')
-
- if is_standard_format(dataset) and dataset_map_fn is not None:
- raise ValueError(
- 'If the custom dataset is already in the XTuner-defined format, '
- 'please set `dataset_map_fn` to None.'
- 'For more details, please refer to Step 1 and Step 5 in the '
- '`Using Custom Datasets` section of the documentation found at'
- ' docs/zh_cn/user_guides/single_turn_conversation.md.')
-
- max_dataset_length = min(max_dataset_length, len(dataset))
- indices = np.random.choice(len(dataset), max_dataset_length, replace=False)
- dataset = dataset.select(indices)
-
- if dataset_map_fn is not None:
- dataset = dataset.map(dataset_map_fn)
-
- print('#' * 20 + ' dataset after `dataset_map_fn` ' + '#' * 20)
- print(dataset[0]['conversation'])
-
- if template_map_fn is not None:
- template_map_fn = BUILDER.build(template_map_fn)
- dataset = dataset.map(template_map_fn)
-
- print('#' * 20 + ' dataset after adding templates ' + '#' * 20)
- print(dataset[0]['conversation'])
-
- for old, new in rename_maps:
- dataset = dataset.rename_column(old, new)
-
- if pack_to_max_length and (not remove_unused_columns):
- raise ValueError('We have to remove unused columns if '
- '`pack_to_max_length` is set to True.')
-
- dataset = dataset.map(
- partial(
- encode_fn,
- tokenizer=tokenizer,
- max_length=max_length,
- input_ids_with_output=input_ids_with_output),
- remove_columns=list(dataset.column_names)
- if remove_unused_columns else None)
-
- print('#' * 20 + ' encoded input_ids ' + '#' * 20)
- print(dataset[0]['input_ids'])
- print('#' * 20 + ' encoded labels ' + '#' * 20)
- print(dataset[0]['labels'])
-
- if pack_to_max_length and split == 'train':
- if shuffle_before_pack:
- dataset = dataset.shuffle()
- dataset = dataset.flatten_indices()
- dataset = dataset.map(Packer(max_length), batched=True)
-
- print('#' * 20 + ' input_ids after packed to max_length ' +
- '#' * 20)
- print(dataset[0]['input_ids'])
- print('#' * 20 + ' labels after packed to max_length ' + '#' * 20)
- print(dataset[0]['labels'])
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/copy_cfg.py b/code/xtuner/tools/copy_cfg.py
deleted file mode 100644
index 9c3ff69c1271ae16fc3ad11d2f7ce184cca5dfea..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/copy_cfg.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os.path as osp
-import shutil
-
-from mmengine.utils import mkdir_or_exist
-
-from xtuner.configs import cfgs_name_path
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('config_name', help='config name')
- parser.add_argument('save_dir', help='save directory for copied config')
- args = parser.parse_args()
- return args
-
-
-def add_copy_suffix(string):
- file_name, ext = osp.splitext(string)
- return f'{file_name}_copy{ext}'
-
-
-def main():
- args = parse_args()
- mkdir_or_exist(args.save_dir)
- config_path = cfgs_name_path[args.config_name]
- save_path = osp.join(args.save_dir,
- add_copy_suffix(osp.basename(config_path)))
- shutil.copyfile(config_path, save_path)
- print(f'Copy to {save_path}')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/data_preprocess/arxiv.py b/code/xtuner/tools/data_preprocess/arxiv.py
deleted file mode 100644
index 55c3004038971462142f1a4a3619edae4d775b34..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/data_preprocess/arxiv.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import json
-from datetime import datetime
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('src_file', help='source file path')
- parser.add_argument('dst_file', help='destination file path')
- parser.add_argument(
- '--categories',
- nargs='+',
- default=['cs.AI', 'cs.CL', 'cs.CV'],
- help='target categories')
- parser.add_argument(
- '--start-date',
- default='2020-01-01',
- help='start date (format: YYYY-MM-DD)')
-
- args = parser.parse_args()
- return args
-
-
-def has_intersection(list1, list2):
- set1 = set(list1)
- set2 = set(list2)
- return len(set1.intersection(set2)) > 0
-
-
-def read_json_file(file_path):
- data = []
- with open(file_path) as file:
- for line in file:
- try:
- json_data = json.loads(line)
- data.append(json_data)
- except json.JSONDecodeError:
- print(f'Failed to parse line: {line}')
- return data
-
-
-def main():
- args = parse_args()
- json_data = read_json_file(args.src_file)
- from_time = datetime.strptime(args.start_date, '%Y-%m-%d')
- filtered_data = [
- item for item in json_data
- if has_intersection(args.categories, item['categories'].split())
- and datetime.strptime(item['update_date'], '%Y-%m-%d') >= from_time
- ]
-
- with open(args.dst_file, 'w') as file:
- json.dump(filtered_data, file)
-
- print(f'Save to {args.dst_file}\n{len(filtered_data)} items')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/data_preprocess/convert_refcoco.py b/code/xtuner/tools/data_preprocess/convert_refcoco.py
deleted file mode 100644
index 883e82a226414f9fbf49e27ed7144bd8e478cfef..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/data_preprocess/convert_refcoco.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import json
-
-from xtuner.dataset.refcoco_json import RefCOCOJsonDataset
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--ann-path',
- default='data/refcoco_annotations',
- help='Refcoco annotation path',
- )
- parser.add_argument(
- '--image-path',
- default='data/llava_data/llava_images/coco/train2017',
- help='COCO image path',
- )
- parser.add_argument(
- '--save-path', default='./', help='The folder to save converted data')
- args = parser.parse_args()
- return args
-
-
-if __name__ == '__main__':
- args = parse_args()
-
- data_info = [
- ('refcoco', 'unc'),
- ('refcoco+', 'unc'),
- ('refcocog', 'umd'),
- ]
- all_data = []
- for dataset, split in data_info:
- data = RefCOCOJsonDataset.get_data_json(
- ann_path=args.ann_path,
- image_path=args.image_path,
- dataset=dataset,
- splitBy=split,
- )[0]
- all_data.extend(data)
- save_path = args.save_path + '/train.json'
- with open(save_path, 'w') as f:
- print(f'save to {save_path} with {len(all_data)} items.')
- print(all_data[0])
- json.dump(all_data, f, indent=4)
diff --git a/code/xtuner/tools/eval_refcoco.py b/code/xtuner/tools/eval_refcoco.py
deleted file mode 100644
index cbdc1bf6e9dda876440ffa61416f66247d1705db..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/eval_refcoco.py
+++ /dev/null
@@ -1,356 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-import re
-
-import torch
-import tqdm
-from huggingface_hub import snapshot_download
-from mmengine.dist import get_dist_info, init_dist, master_only
-from mmengine.utils.dl_utils import set_multi_processing
-from peft import PeftModel
-from torch import distributed as dist
-from torch.utils.data import DataLoader, DistributedSampler
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-
-from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
-from xtuner.dataset.refcoco_json import RefCOCOJsonEvalDataset
-from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
-from xtuner.tools.utils import get_stop_criteria
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- PROMPT_TEMPLATE)
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-
-def merge_outputs(otuputs):
- new_outputs = [None for _ in range(dist.get_world_size())]
-
- assert dist.is_initialized()
-
- dist.all_gather_object(new_outputs, otuputs)
- new_dict = []
- for output in new_outputs:
- new_dict.extend(output)
- return new_dict
-
-
-@master_only
-def master_print(msg):
- print(msg)
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description='MMBench')
- parser.add_argument(
- 'model_name_or_path', help='Hugging Face model name or path')
- parser.add_argument('--data-path', default=None, help='data path')
- parser.add_argument('--work-dir', help='the dir to save results')
- parser.add_argument('--llava', default=None, help='llava name or path')
- parser.add_argument(
- '--visual-encoder', default=None, help='visual encoder name or path')
- parser.add_argument(
- '--visual-select-layer', default=-2, help='visual select layer')
- parser.add_argument(
- '--prompt-template',
- choices=PROMPT_TEMPLATE.keys(),
- default=None,
- help='Specify a prompt template',
- )
- parser.add_argument(
- '--stop-words', nargs='+', type=str, default=[], help='Stop words')
- parser.add_argument(
- '--torch-dtype',
- default='fp16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.',
- )
- parser.add_argument(
- '--bits',
- type=int,
- choices=[4, 8, None],
- default=None,
- help='LLM bits')
- parser.add_argument(
- '--bot-name', type=str, default='BOT', help='Name for Bot')
- parser.add_argument(
- '--offload-folder',
- default=None,
- help='The folder in which to offload the model weights (or where the '
- 'model weights are already offloaded).',
- )
- parser.add_argument(
- '--max-new-tokens',
- type=int,
- default=100,
- help='Maximum number of new tokens allowed in generated text',
- )
- parser.add_argument(
- '--seed',
- type=int,
- default=0,
- help='Random seed for reproducible text generation',
- )
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher',
- )
- args = parser.parse_args()
- return args
-
-
-def eval_iou(answers):
-
- def computeIoU(bbox1, bbox2):
- x1, y1, x2, y2 = bbox1
- x3, y3, x4, y4 = bbox2
- intersection_x1 = max(x1, x3)
- intersection_y1 = max(y1, y3)
- intersection_x2 = min(x2, x4)
- intersection_y2 = min(y2, y4)
- intersection_area = max(0,
- intersection_x2 - intersection_x1 + 1) * max(
- 0, intersection_y2 - intersection_y1 + 1)
- bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
- bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
- union_area = bbox1_area + bbox2_area - intersection_area
- iou = intersection_area / union_area
- return iou
-
- right = 0
- for answer in answers:
- bbox = answer['bbox']
- bbox = RefCOCOJsonEvalDataset.normalize_bbox(bbox, answer['height'],
- answer['width'])
- answer_bbox = [int(x) for x in re.findall(r'\d+', answer['ans'])]
- if len(answer_bbox) == 4:
- iou = computeIoU(answer_bbox, bbox)
- if iou > 0.5:
- right += 1
- else:
- print('Error format sample: ', answer)
- return right / len(answers)
-
-
-def build_model(args):
- rank, world_size = get_dist_info()
- # build llm
- quantization_config = None
- load_in_8bit = False
- if args.bits == 4:
- quantization_config = BitsAndBytesConfig(
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4',
- )
- elif args.bits == 8:
- load_in_8bit = True
- model_kwargs = {
- 'quantization_config': quantization_config,
- 'load_in_8bit': load_in_8bit,
- 'device_map': rank if world_size > 1 else 'auto',
- 'offload_folder': args.offload_folder,
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype],
- }
-
- # build llm
- with LoadWoInit():
- llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
- **model_kwargs)
- tokenizer = AutoTokenizer.from_pretrained(
- args.model_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
- master_print(f'Load LLM from {args.model_name_or_path}')
-
- llava_path = (
- snapshot_download(
- repo_id=args.llava) if not osp.isdir(args.llava) else args.llava)
-
- # build visual_encoder
- if 'visual_encoder' in os.listdir(llava_path):
- assert args.visual_encoder is None, (
- "Please don't specify the `--visual-encoder` since passed "
- '`--llava` contains a visual encoder!')
- visual_encoder_path = osp.join(llava_path, 'visual_encoder')
- else:
- assert (args.visual_encoder is not None
- ), 'Please specify the `--visual-encoder`!' # noqa: E501
- visual_encoder_path = args.visual_encoder
- with LoadWoInit():
- visual_encoder = CLIPVisionModel.from_pretrained(
- visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
- image_processor = CLIPImageProcessor.from_pretrained(
- visual_encoder_path)
- master_print(f'Load visual_encoder from {visual_encoder_path}')
-
- # load adapter
- if 'llm_adapter' in os.listdir(llava_path):
- adapter_path = osp.join(llava_path, 'llm_adapter')
-
- with LoadWoInit():
- llm = PeftModel.from_pretrained(
- llm, adapter_path, offload_folder=args.offload_folder)
-
- master_print(f'Load LLM adapter from {args.llava}')
-
- if 'visual_encoder_adapter' in os.listdir(llava_path):
- adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
- visual_encoder = PeftModel.from_pretrained(
- visual_encoder, adapter_path, offload_folder=args.offload_folder)
- master_print(f'Load visual_encoder adapter from {args.llava}')
-
- # build projector
- projector_path = osp.join(llava_path, 'projector')
- with LoadWoInit():
- projector = AutoModel.from_pretrained(
- projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
- master_print(f'Load projector from {args.llava}')
-
- projector.cuda()
- projector.eval()
-
- visual_encoder.cuda()
- visual_encoder.eval()
-
- llm.eval()
- return llm, visual_encoder, projector, tokenizer, image_processor
-
-
-def generate(
- llm,
- visual_encoder,
- projector,
- tokenizer,
- samples,
- visual_select_layer,
-):
- gen_config = GenerationConfig(
- max_new_tokens=100,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=(tokenizer.pad_token_id if tokenizer.pad_token_id
- is not None else tokenizer.eos_token_id),
- )
- stop_criteria = get_stop_criteria(tokenizer=tokenizer, stop_words=[''])
-
- device = next(llm.parameters()).device
- # prepare inputs
- inputs = samples['conversation'][0]['input'][0]
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- ids.append(IMAGE_TOKEN_INDEX)
- ids = torch.tensor(ids).cuda().unsqueeze(0)
-
- visual_outputs = visual_encoder(
- samples['pixel_values'].to(device), output_hidden_states=True)
- pixel_values = projector(
- visual_outputs.hidden_states[visual_select_layer][:, 1:])
- samples['pixel_values'] = pixel_values
- samples['input_ids'] = ids
- datax = prepare_inputs_labels_for_multimodal(
- llm=llm.to(device),
- input_ids=samples['input_ids'].to(device),
- pixel_values=samples['pixel_values'].to(device),
- )
-
- # generation
- generation = llm.generate(
- **datax,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria,
- )
- answer = tokenizer.decode(generation[0])
- return {
- 'ans': answer,
- 'id': samples['id'][0],
- 'bbox': torch.tensor(samples['bbox']).tolist(),
- 'height': samples['height'],
- 'width': samples['width'],
- }
-
-
-@torch.no_grad()
-def main():
- # init
- args = parse_args()
- if args.launcher != 'none':
- set_multi_processing(distributed=True)
- init_dist(args.launcher)
-
- rank, world_size = get_dist_info()
- torch.cuda.set_device(rank)
- else:
- rank = 0
- world_size = 1
- print(f'Rank: {rank} / World size: {world_size}')
-
- # build_model
- llm, visual_encoder, projector, tokenizer, image_processor = build_model(
- args)
-
- # dataset
- dataset = RefCOCOJsonEvalDataset(
- data_path=args.data_path,
- image_folder='data/llava_data/llava_images/',
- tokenizer=tokenizer,
- image_processor=image_processor,
- max_dataset_length=None,
- dataset_map_fn=llava_map_fn,
- template_map_fn=dict(
- type=template_map_fn_factory, template=PROMPT_TEMPLATE.vicuna),
- max_length=2048,
- pad_image_to_square=False,
- )
- loader = DataLoader(
- dataset,
- batch_size=1,
- shuffle=False,
- sampler=DistributedSampler(dataset, shuffle=False, seed=0),
- )
- loader.sampler.set_epoch(0)
-
- answers = []
- for i, data in tqdm.tqdm(enumerate(loader), desc=f'Rank {rank}'):
- answer = generate(
- llm,
- visual_encoder,
- projector,
- tokenizer,
- data,
- args.visual_select_layer,
- )
- answers.append(answer)
-
- merged_outputs = merge_outputs(answers)
- acc = eval_iou(merged_outputs)
- master_print(f'Acc: {acc}')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/get_data_order.py b/code/xtuner/tools/get_data_order.py
deleted file mode 100644
index 30c23e84e7213fb518f798946da0befb1091b8c2..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/get_data_order.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('--data-folder', help='Data folder')
- parser.add_argument('--save-folder', help='The folder to save data order.')
- parser.add_argument(
- '--file-type',
- default='.bin',
- help='We want to get the order of the file in this type.')
- args = parser.parse_args()
- return args
-
-
-def save_data_order(data_folder, save_folder, file_type='.bin'):
- assert os.path.exists(data_folder), f'{data_folder} does not exist.'
- triples = list(os.walk(data_folder, followlinks=True))
- data_order = []
- for root, dirs, files in triples:
- dirs.sort()
- print(f'Reading {root}...')
- for fn in sorted(files):
- if fn.endswith(file_type):
- fp = os.path.join(root, fn)
- # Using relative paths so that you can get the same result
- # on different clusters
- fp = fp.replace(data_folder, '')[1:]
- data_order.append(fp)
-
- save_path = os.path.join(save_folder, 'data_order.txt')
- with open(save_path, 'w') as f:
- for fp in data_order:
- f.write(fp + '\n')
-
-
-if __name__ == '__main__':
- args = parse_args()
- save_data_order(args.data_folder, args.save_folder, args.file_type)
diff --git a/code/xtuner/tools/list_cfg.py b/code/xtuner/tools/list_cfg.py
deleted file mode 100644
index 0062ade5714aa5b30467ab53809d245f8c142f66..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/list_cfg.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-
-from xtuner.configs import cfgs_name_path
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '-p', '--pattern', default=None, help='Pattern for fuzzy matching')
- args = parser.parse_args()
- return args
-
-
-def main(pattern=None):
- args = parse_args()
- configs_names = sorted(list(cfgs_name_path.keys()))
- print('==========================CONFIGS===========================')
- if args.pattern is not None:
- print(f'PATTERN: {args.pattern}')
- print('-------------------------------')
- for name in configs_names:
- if args.pattern is None or args.pattern.lower() in name.lower():
- print(name)
- print('=============================================================')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/list_dataset_format.py b/code/xtuner/tools/list_dataset_format.py
deleted file mode 100644
index 40d3a71f2539db6b0af2880d78c0e2710c296dfe..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/list_dataset_format.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from xtuner.dataset.map_fns import DATASET_FORMAT_MAPPING
-
-
-def main():
- dataset_format = DATASET_FORMAT_MAPPING.keys()
- print('======================DATASET_FORMAT======================')
- for format in dataset_format:
- print(format)
- print('==========================================================')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/log_dataset.py b/code/xtuner/tools/log_dataset.py
deleted file mode 100644
index 40b5e25feff74d90cff8ffeaa74fd6b103d649a9..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/log_dataset.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-
-from mmengine.config import Config
-
-from xtuner.registry import BUILDER
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Log processed dataset.')
- parser.add_argument('config', help='config file name or path.')
- # chose which kind of dataset style to show
- parser.add_argument(
- '--show',
- default='text',
- choices=['text', 'masked_text', 'input_ids', 'labels', 'all'],
- help='which kind of dataset style to show')
- args = parser.parse_args()
- return args
-
-
-def main():
- args = parse_args()
-
- cfg = Config.fromfile(args.config)
-
- tokenizer = BUILDER.build(cfg.tokenizer)
- if cfg.get('framework', 'mmengine').lower() == 'huggingface':
- train_dataset = BUILDER.build(cfg.train_dataset)
- else:
- train_dataset = BUILDER.build(cfg.train_dataloader.dataset)
-
- if args.show == 'text' or args.show == 'all':
- print('#' * 20 + ' text ' + '#' * 20)
- print(tokenizer.decode(train_dataset[0]['input_ids']))
- if args.show == 'masked_text' or args.show == 'all':
- print('#' * 20 + ' text(masked) ' + '#' * 20)
- masked_text = ' '.join(
- ['[-100]' for i in train_dataset[0]['labels'] if i == -100])
- unmasked_text = tokenizer.decode(
- [i for i in train_dataset[0]['labels'] if i != -100])
- print(masked_text + ' ' + unmasked_text)
- if args.show == 'input_ids' or args.show == 'all':
- print('#' * 20 + ' input_ids ' + '#' * 20)
- print(train_dataset[0]['input_ids'])
- if args.show == 'labels' or args.show == 'all':
- print('#' * 20 + ' labels ' + '#' * 20)
- print(train_dataset[0]['labels'])
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/mmbench.py b/code/xtuner/tools/mmbench.py
deleted file mode 100644
index 24d3825bb2ded3be9b11aaee18f312e86342223e..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/mmbench.py
+++ /dev/null
@@ -1,513 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import json
-import math
-import os
-import os.path as osp
-import re
-import string
-import time
-
-import numpy as np
-import pandas as pd
-import torch
-import tqdm
-from huggingface_hub import snapshot_download
-from mmengine import mkdir_or_exist
-from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
- master_only)
-from mmengine.utils.dl_utils import set_multi_processing
-from peft import PeftModel
-from rich.console import Console
-from rich.table import Table
-from torch.utils.data import Dataset
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-
-from xtuner.dataset.utils import decode_base64_to_image, expand2square
-from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
-from xtuner.tools.utils import get_stop_criteria, is_cn_string
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- PROMPT_TEMPLATE)
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description='MMBench')
- parser.add_argument(
- 'model_name_or_path', help='Hugging Face model name or path')
- parser.add_argument('--data-path', default=None, help='data path')
- parser.add_argument('--work-dir', help='the dir to save results')
- parser.add_argument('--llava', default=None, help='llava name or path')
- parser.add_argument(
- '--visual-encoder', default=None, help='visual encoder name or path')
- parser.add_argument(
- '--visual-select-layer', default=-2, help='visual select layer')
- parser.add_argument(
- '--prompt-template',
- choices=PROMPT_TEMPLATE.keys(),
- default=None,
- help='Specify a prompt template')
- parser.add_argument(
- '--stop-words', nargs='+', type=str, default=[], help='Stop words')
- parser.add_argument(
- '--torch-dtype',
- default='fp16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--bits',
- type=int,
- choices=[4, 8, None],
- default=None,
- help='LLM bits')
- parser.add_argument(
- '--bot-name', type=str, default='BOT', help='Name for Bot')
- parser.add_argument(
- '--offload-folder',
- default=None,
- help='The folder in which to offload the model weights (or where the '
- 'model weights are already offloaded).')
- parser.add_argument(
- '--max-new-tokens',
- type=int,
- default=100,
- help='Maximum number of new tokens allowed in generated text')
- parser.add_argument(
- '--seed',
- type=int,
- default=0,
- help='Random seed for reproducible text generation')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- args = parser.parse_args()
- return args
-
-
-@master_only
-def master_print(msg):
- print(msg)
-
-
-class MMBenchDataset(Dataset):
- ABBRS = {
- 'coarse_perception': 'CP',
- 'finegrained_perception (instance-level)': 'FP-S',
- 'finegrained_perception (cross-instance)': 'FP-C',
- 'logic_reasoning': 'LR',
- 'relation_reasoning': 'RR',
- 'attribute_reasoning': 'AR',
- 'sketch_reasoning': 'Sketch Reasoning',
- 'scenery_building': 'Scenery & Building',
- 'food_clothes': 'Food & Clothes',
- 'historical_figure': 'Historical Figure',
- 'traditional_show': 'Traditional Show',
- 'calligraphy_painting': 'Calligraphy Painting',
- 'cultural_relic': 'Cultural Relic'
- }
-
- def __init__(self, data_file):
- self.data_file = data_file
- self.df = pd.read_csv(data_file, sep='\t')
- self.split = 'dev' if 'answer' in self.df.iloc[0].keys() else 'test'
- self.has_l2_category = 'l2-category' in self.df.columns.to_list()
-
- def get_image(self, image):
- while len(image) < 16:
- image = self.df[self.df['index'] == int(image)]['image'].values
- assert len(image) == 1
- image = image[0]
- image = decode_base64_to_image(image)
- return image
-
- def __len__(self):
- return len(self.df)
-
- def __getitem__(self, idx):
- index = self.df.iloc[idx]['index']
- image = self.df.iloc[idx]['image']
- image = self.get_image(image)
- question = self.df.iloc[idx]['question']
- answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[
- 0].keys() else None
- category = self.df.iloc[idx]['category']
-
- options = {
- cand: self.load_from_df(idx, cand)
- for cand in string.ascii_uppercase
- if self.load_from_df(idx, cand) is not None
- }
- options_prompt = ''
- for key, item in options.items():
- options_prompt += f'{key}. {item}\n'
-
- hint = self.load_from_df(idx, 'hint')
- data = {
- 'img': image,
- 'question': question,
- 'answer': answer,
- 'options': options_prompt,
- 'category': category,
- 'options_dict': options,
- 'index': index,
- 'context': hint,
- }
- if self.has_l2_category:
- data.update({'l2-category': self.df.iloc[idx]['l2-category']})
- return data
-
- def load_from_df(self, idx, key):
- if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
- return self.df.iloc[idx][key]
- else:
- return None
-
- @master_only
- def eval_result(self, result_df, show=True):
-
- def calc_acc(df, group='category'):
- assert group in ['overall', 'category', 'l2-category']
- if group == 'overall':
- res = {'Average': np.mean(df['hit'])}
- else:
- res = {}
- abilities = list(set(df[group]))
- abilities.sort()
- for ab in abilities:
- sub_df = df[df[group] == ab]
- ab = self.ABBRS[ab] if ab in self.ABBRS else ab
- res[ab] = np.mean(sub_df['hit'])
- return res
-
- def eval_sub_data(sub_data, answer_map):
- lt = len(sub_data)
- for i in range(lt):
- item = sub_data.iloc[i]
- match = re.search(r'([A-D]+)', item['prediction'])
- pred = match.group(1) if match else ''
- gt = answer_map[item['index']]
- if gt != pred:
- return 0
- return 1
-
- def show_result(ret_json):
- show_dict = ret_json.copy()
- table = Table(title=f' MMBench ({self.data_file}) ')
- console = Console()
- table.add_column('Category', justify='left')
- table.add_column('Accuracy (%)', justify='right')
- average = show_dict.pop('Average') * 100
- table.add_row('Average', f'{average:.1f}')
- table.add_section()
- for cat_name, cat_acc in show_dict.items():
- table.add_row(cat_name, f'{cat_acc * 100:.1f}')
- with console.capture() as capture:
- console.print(table, end='')
- print('\n' + capture.get())
- print('Note: Please be cautious if you use the results in papers, '
- "since we don't use ChatGPT as a helper for choice "
- 'extraction')
-
- data = result_df.sort_values(by='index')
- data['prediction'] = [str(x) for x in data['prediction']]
- for k in data.keys():
- data[k.lower() if k not in 'ABCD' else k] = data.pop(k)
-
- data_main = data[data['index'] < int(1e6)]
- cate_map = {
- i: c
- for i, c in zip(self.df['index'], self.df['category'])
- }
- if self.has_l2_category:
- l2_cate_map = {
- i: c
- for i, c in zip(self.df['index'], self.df['l2-category'])
- }
- answer_map = {
- i: c
- for i, c in zip(self.df['index'], self.df['answer'])
- }
-
- lt = len(data_main)
- hit, tot = 0, 0
- result = {}
- for i in range(lt):
- item_main = data_main.iloc[i]
- idx = item_main['index']
- assert idx not in result
- sub_data = data[data['index'] % int(1e6) == idx]
- ret = eval_sub_data(sub_data, answer_map)
- result[idx] = ret
- hit += ret
- tot += 1
-
- indices = data_main['index']
- data_main = data_main.copy()
- data_main['hit'] = [result[i] for i in indices]
- main_idx = data_main['index']
- data_main['category'] = [cate_map[i] for i in main_idx]
-
- ret_json = calc_acc(data_main, 'overall')
-
- if self.has_l2_category:
- data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
- l2 = calc_acc(data_main, 'l2-category')
- ret_json.update(l2)
- else:
- leaf = calc_acc(data_main, 'category')
- ret_json.update(leaf)
- if show:
- show_result(ret_json)
- return ret_json
-
-
-def main():
- args = parse_args()
-
- torch.manual_seed(args.seed)
-
- if args.launcher != 'none':
- set_multi_processing(distributed=True)
- init_dist(args.launcher)
-
- rank, world_size = get_dist_info()
- torch.cuda.set_device(rank)
- else:
- rank = 0
- world_size = 1
-
- # build llm
- quantization_config = None
- load_in_8bit = False
- if args.bits == 4:
- quantization_config = BitsAndBytesConfig(
- load_in_4bit=True,
- load_in_8bit=False,
- llm_int8_threshold=6.0,
- llm_int8_has_fp16_weight=False,
- bnb_4bit_compute_dtype=torch.float16,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type='nf4')
- elif args.bits == 8:
- load_in_8bit = True
- model_kwargs = {
- 'quantization_config': quantization_config,
- 'load_in_8bit': load_in_8bit,
- 'device_map': rank if world_size > 1 else 'auto',
- 'offload_folder': args.offload_folder,
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- }
-
- # build llm
- with LoadWoInit():
- llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
- **model_kwargs)
- tokenizer = AutoTokenizer.from_pretrained(
- args.model_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
- master_print(f'Load LLM from {args.model_name_or_path}')
-
- llava_path = snapshot_download(
- repo_id=args.llava) if not osp.isdir(args.llava) else args.llava
-
- # build visual_encoder
- if 'visual_encoder' in os.listdir(llava_path):
- assert args.visual_encoder is None, (
- "Please don't specify the `--visual-encoder` since passed "
- '`--llava` contains a visual encoder!')
- visual_encoder_path = osp.join(llava_path, 'visual_encoder')
- else:
- assert args.visual_encoder is not None, (
- 'Please specify the `--visual-encoder`!')
- visual_encoder_path = args.visual_encoder
- with LoadWoInit():
- visual_encoder = CLIPVisionModel.from_pretrained(
- visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
- image_processor = CLIPImageProcessor.from_pretrained(
- visual_encoder_path)
- master_print(f'Load visual_encoder from {visual_encoder_path}')
-
- # load adapter
- if 'llm_adapter' in os.listdir(llava_path):
- adapter_path = osp.join(llava_path, 'llm_adapter')
-
- with LoadWoInit():
- llm = PeftModel.from_pretrained(
- llm, adapter_path, offload_folder=args.offload_folder)
-
- master_print(f'Load LLM adapter from {args.llava}')
-
- if 'visual_encoder_adapter' in os.listdir(llava_path):
- adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
- visual_encoder = PeftModel.from_pretrained(
- visual_encoder, adapter_path, offload_folder=args.offload_folder)
- master_print(f'Load visual_encoder adapter from {args.llava}')
-
- # build projector
- projector_path = osp.join(llava_path, 'projector')
- with LoadWoInit():
- projector = AutoModel.from_pretrained(
- projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
- master_print(f'Load projector from {args.llava}')
-
- projector.cuda()
- projector.eval()
-
- visual_encoder.cuda()
- visual_encoder.eval()
-
- llm.eval()
-
- stop_words = args.stop_words
- if args.prompt_template:
- template = PROMPT_TEMPLATE[args.prompt_template]
- stop_words += template.get('STOP_WORDS', [])
- stop_criteria = get_stop_criteria(
- tokenizer=tokenizer, stop_words=stop_words)
-
- gen_config = GenerationConfig(
- max_new_tokens=args.max_new_tokens,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
-
- # work_dir
- if args.work_dir is not None:
- # update configs according to CLI args if args.work_dir is not None
- save_dir = args.work_dir
- else:
- # use config filename as default work_dir
- save_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.data_path))[0])
- timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
- save_dir = osp.join(save_dir, timestamp)
-
- if rank == 0:
- mkdir_or_exist(osp.abspath(save_dir))
- print('=======================================================')
- print(f'Dataset path: {osp.abspath(args.data_path)}\n'
- f'Results will be saved to {osp.abspath(save_dir)}')
- print('=======================================================')
-
- args_path = osp.join(save_dir, 'args.json')
- with open(args_path, 'w', encoding='utf-8') as f:
- json.dump(args.__dict__, f, indent=2)
-
- results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
- results_json_path = osp.join(save_dir, 'mmbench_result.json')
-
- dataset = MMBenchDataset(args.data_path)
-
- results = []
- n_samples = len(dataset)
- per_rank_samples = math.ceil(n_samples / world_size)
-
- per_rank_ids = range(per_rank_samples * rank,
- min(n_samples, per_rank_samples * (rank + 1)))
- for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
- data_sample = dataset[i]
- if data_sample['context'] is not None:
- text = data_sample['context'] + '\n' + data_sample[
- 'question'] + '\n' + data_sample['options']
- else:
- text = data_sample['question'] + '\n' + data_sample['options']
-
- text = DEFAULT_IMAGE_TOKEN + '\n' + text
-
- if is_cn_string(text):
- text = text + '请直接回答选项字母。'
- else:
- text = text + ("Answer with the option's letter from the "
- 'given choices directly.')
-
- if args.prompt_template:
- prompt_text = ''
- template = PROMPT_TEMPLATE[args.prompt_template]
- prompt_text += template['INSTRUCTION'].format(
- input=text, round=1, bot_name=args.bot_name)
- else:
- prompt_text = text
- inputs = prompt_text
-
- image = data_sample['img'].convert('RGB')
- image = expand2square(
- image, tuple(int(x * 255) for x in image_processor.image_mean))
- image = image_processor.preprocess(
- image, return_tensors='pt')['pixel_values'][0]
- image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
- visual_outputs = visual_encoder(image, output_hidden_states=True)
- pixel_values = projector(
- visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
-
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
-
- # TODO: Auto-detect whether to prepend a bos_token_id at the beginning.
- ids = []
-
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- ids.append(IMAGE_TOKEN_INDEX)
- ids = torch.tensor(ids).cuda().unsqueeze(0)
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=llm, input_ids=ids, pixel_values=pixel_values)
-
- generate_output = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
-
- predict = tokenizer.decode(
- generate_output[0], skip_special_tokens=True).strip()
- cur_result = {}
- cur_result['question'] = data_sample.get('question')
- cur_result.update(data_sample.get('options_dict'))
- cur_result['prediction'] = predict
- if data_sample.get('category') is not None:
- cur_result['category'] = data_sample.get('category')
- if data_sample.get('l2-category') is not None:
- cur_result['l2-category'] = data_sample.get('l2-category')
- cur_result['index'] = data_sample.get('index')
- cur_result['split'] = data_sample.get('split')
- cur_result['answer'] = data_sample.get('answer')
- results.append(cur_result)
-
- results = collect_results(results, n_samples)
-
- if get_rank() == 0:
-
- results_df = pd.DataFrame(results)
- with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
- results_df.to_excel(writer, index=False)
-
- if dataset.split == 'dev':
- results_dict = dataset.eval_result(results_df, show=True)
- with open(results_json_path, 'w', encoding='utf-8') as f:
- json.dump(results_dict, f, indent=2)
- else:
- print('All done!')
-
-
-if __name__ == '__main__':
-
- main()
diff --git a/code/xtuner/tools/model_converters/__pycache__/pth_to_hf.cpython-311.pyc b/code/xtuner/tools/model_converters/__pycache__/pth_to_hf.cpython-311.pyc
deleted file mode 100644
index 605dfad3c98a8aef416d40a292768ec75b5eb457..0000000000000000000000000000000000000000
Binary files a/code/xtuner/tools/model_converters/__pycache__/pth_to_hf.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/tools/model_converters/merge.py b/code/xtuner/tools/model_converters/merge.py
deleted file mode 100644
index c7202a6633aa4f42e4082c81048a0053fd9e64c6..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/model_converters/merge.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-
-import torch
-from peft import PeftModel
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
- CLIPImageProcessor, CLIPVisionModel)
-
-from xtuner.model.utils import LoadWoInit
-
-
-def parse_args():
- parser = argparse.ArgumentParser(
- description='Merge a HuggingFace adapter to base model')
- parser.add_argument('model_name_or_path', help='model name or path')
- parser.add_argument('adapter_name_or_path', help='adapter name or path')
- parser.add_argument(
- 'save_dir', help='the directory to save the merged model')
- parser.add_argument(
- '--max-shard-size',
- type=str,
- default='2GB',
- help='Only applicable for LLM. The maximum size for '
- 'each sharded checkpoint.')
- parser.add_argument(
- '--is-clip',
- action='store_true',
- help='Indicate if the model is a clip model')
- parser.add_argument(
- '--safe-serialization',
- action='store_true',
- help='Indicate if using `safe_serialization`')
- parser.add_argument(
- '--device',
- default='cuda',
- choices=('cuda', 'cpu', 'auto'),
- help='Indicate the device')
-
- args = parser.parse_args()
- return args
-
-
-def main():
- args = parse_args()
- if args.is_clip:
- with LoadWoInit():
- model = CLIPVisionModel.from_pretrained(
- args.model_name_or_path, device_map=args.device)
- processor = CLIPImageProcessor.from_pretrained(args.model_name_or_path)
- else:
- with LoadWoInit():
- model = AutoModelForCausalLM.from_pretrained(
- args.model_name_or_path,
- torch_dtype=torch.float16,
- low_cpu_mem_usage=True,
- device_map=args.device,
- trust_remote_code=True)
- processor = AutoTokenizer.from_pretrained(
- args.model_name_or_path, trust_remote_code=True)
- model_unmerged = PeftModel.from_pretrained(
- model,
- args.adapter_name_or_path,
- device_map=args.device,
- is_trainable=False,
- trust_remote_code=True)
- model_merged = model_unmerged.merge_and_unload()
- print(f'Saving to {args.save_dir}...')
- model_merged.save_pretrained(
- args.save_dir,
- safe_serialization=args.safe_serialization,
- max_shard_size=args.max_shard_size)
- processor.save_pretrained(args.save_dir)
- print('All done!')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/model_converters/modeling_internlm2_reward/__init__.py b/code/xtuner/tools/model_converters/modeling_internlm2_reward/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/code/xtuner/tools/model_converters/modeling_internlm2_reward/configuration_internlm2.py b/code/xtuner/tools/model_converters/modeling_internlm2_reward/configuration_internlm2.py
deleted file mode 100644
index 12fdffe28ca875049873cfd010ac59ddf68af6c2..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/model_converters/modeling_internlm2_reward/configuration_internlm2.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# coding=utf-8
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" InternLM2 model configuration"""
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
-
-
-# Modified from transformers.model.llama.configuration_llama.LlamaConfig
-class InternLM2Config(PretrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
- an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
- configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
-
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
-
-
- Args:
- vocab_size (`int`, *optional*, defaults to 32000):
- Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
- `inputs_ids` passed when calling [`InternLM2Model`]
- hidden_size (`int`, *optional*, defaults to 4096):
- Dimension of the hidden representations.
- intermediate_size (`int`, *optional*, defaults to 11008):
- Dimension of the MLP representations.
- num_hidden_layers (`int`, *optional*, defaults to 32):
- Number of hidden layers in the Transformer encoder.
- num_attention_heads (`int`, *optional*, defaults to 32):
- Number of attention heads for each attention layer in the Transformer encoder.
- num_key_value_heads (`int`, *optional*):
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
- by meanpooling all the original heads within that group. For more details checkout [this
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
- `num_attention_heads`.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the decoder.
- max_position_embeddings (`int`, *optional*, defaults to 2048):
- The maximum sequence length that this model might ever be used with. Typically set this to something large
- just in case (e.g., 512 or 1024 or 2048).
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- rms_norm_eps (`float`, *optional*, defaults to 1e-12):
- The epsilon used by the rms normalization layers.
- use_cache (`bool`, *optional*, defaults to `True`):
- Whether or not the model should return the last key/values attentions (not used by all models). Only
- relevant if `config.is_decoder=True`.
- tie_word_embeddings(`bool`, *optional*, defaults to `False`):
- Whether to tie weight embeddings
- Example:
-
- """
- model_type = "internlm2"
- _auto_class = "AutoConfig"
-
- def __init__( # pylint: disable=W0102
- self,
- vocab_size=103168,
- hidden_size=4096,
- intermediate_size=11008,
- num_hidden_layers=32,
- num_attention_heads=32,
- num_key_value_heads=None,
- hidden_act="silu",
- max_position_embeddings=2048,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- pad_token_id=0,
- bos_token_id=1,
- eos_token_id=2,
- reward_token_id=92527,
- tie_word_embeddings=False,
- bias=True,
- rope_theta=10000,
- rope_scaling=None,
- attn_implementation="eager",
- **kwargs,
- ):
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.bias = bias
-
- if num_key_value_heads is None:
- num_key_value_heads = num_attention_heads
- self.num_key_value_heads = num_key_value_heads
-
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
- self._rope_scaling_validation()
-
- self.attn_implementation = attn_implementation
- if self.attn_implementation is None:
- self.attn_implementation = "eager"
-
- self.reward_token_id = reward_token_id
- super().__init__(
- pad_token_id=pad_token_id,
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- tie_word_embeddings=tie_word_embeddings,
- **kwargs,
- )
-
- def _rope_scaling_validation(self):
- """
- Validate the `rope_scaling` configuration.
- """
- if self.rope_scaling is None:
- return
-
- if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
- raise ValueError(
- "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
- f"got {self.rope_scaling}"
- )
- rope_scaling_type = self.rope_scaling.get("type", None)
- rope_scaling_factor = self.rope_scaling.get("factor", None)
- if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
- raise ValueError(
- f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
- )
- if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
- raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
diff --git a/code/xtuner/tools/model_converters/modeling_internlm2_reward/modeling_internlm2.py b/code/xtuner/tools/model_converters/modeling_internlm2_reward/modeling_internlm2.py
deleted file mode 100644
index 59cba84567a2c6871bdf45d12a0753a663ea87dc..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/model_converters/modeling_internlm2_reward/modeling_internlm2.py
+++ /dev/null
@@ -1,1578 +0,0 @@
-# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" PyTorch InternLM2 model."""
-import math
-import queue
-import threading
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from einops import rearrange
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast,
-)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
-)
-
-try:
- from transformers.generation.streamers import BaseStreamer
-except: # noqa # pylint: disable=bare-except
- BaseStreamer = None
-
-from .configuration_internlm2 import InternLM2Config
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "InternLM2Config"
-
-flash_attn_func, flash_attn_varlen_func = None, None
-pad_input, index_first_axis, unpad_input = None, None, None
-def _import_flash_attn():
- global flash_attn_func, flash_attn_varlen_func
- global pad_input, index_first_axis, unpad_input
- try:
- from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
- from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
- flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
- pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
- except ImportError:
- raise ImportError("flash_attn is not installed.")
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
-
-
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
-):
- """
- Make causal mask used for bi-directional self-attention.
- """
- bsz, tgt_len = input_ids_shape
- mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
- mask_cond = torch.arange(mask.size(-1), device=device)
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
- mask = mask.to(dtype)
-
- if past_key_values_length > 0:
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
-class InternLM2RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- InternLM2RMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
-class InternLM2RotaryEmbedding(nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
-
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
- )
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
-
-
-# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
-class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
- """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
- self.scaling_factor = scaling_factor
- super().__init__(dim, max_position_embeddings, base, device)
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
- t = t / self.scaling_factor
-
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
-class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
- """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
- Credits to the Reddit users /u/bloc97 and /u/emozilla.
- """
-
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
- self.scaling_factor = scaling_factor
- super().__init__(dim, max_position_embeddings, base, device)
-
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
-
- if seq_len > self.max_position_embeddings:
- base = self.base * (
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
- ) ** (self.dim / (self.dim - 2))
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Copied from transformers.model.llama.modeling_llama.rotate_half
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors."""
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-class InternLM2MLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_act]
-
- def forward(self, x):
- down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
-
- return down_proj
-
-
-# Copied from transformers.model.llama.modeling_llama.repeat_kv
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-# Modified from transformers.model.llama.modeling_llama.LlamaAttention
-class InternLM2Attention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: InternLM2Config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.num_key_value_heads = config.num_key_value_heads
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
- self.max_position_embeddings = config.max_position_embeddings
- self.is_causal = True
-
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
-
- self.wqkv = nn.Linear(
- self.hidden_size,
- (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
- bias=config.bias,
- )
-
- self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
- self._init_rope()
-
- def _init_rope(self):
- if self.config.rope_scaling is None:
- self.rotary_emb = InternLM2RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.config.rope_theta,
- )
- else:
- scaling_type = self.config.rope_scaling["type"]
- scaling_factor = self.config.rope_scaling["factor"]
- if scaling_type == "dynamic":
- self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.config.rope_theta,
- scaling_factor=scaling_factor,
- )
- elif scaling_type == "linear":
- self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.config.rope_theta,
- scaling_factor=scaling_factor,
- )
- else:
- raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
- return self.rotary_emb
-
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. "
- "Please make sure use `attention_mask` instead.`"
- )
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- "b q (h gs d) -> b q h gs d",
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., : self.num_key_value_groups, :]
- query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights + attention_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weights, value_states)
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
-class InternLM2FlashAttention2(InternLM2Attention):
- """
- InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # InternLM2FlashAttention2 attention does not support output_attentions
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. "
- "Please make sure use `attention_mask` instead.`"
- )
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop("padding_mask")
-
- output_attentions = False
-
- bsz, q_len, _ = hidden_states.size()
-
- qkv_states = self.wqkv(hidden_states)
-
- qkv_states = rearrange(
- qkv_states,
- "b q (h gs d) -> b q h gs d",
- gs=2 + self.num_key_value_groups,
- d=self.head_dim,
- )
-
- query_states = qkv_states[..., : self.num_key_value_groups, :]
- query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
- key_states = qkv_states[..., -2, :]
- value_states = qkv_states[..., -1, :]
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- attn_output = self._flash_attention_forward(
- query_states, key_states, value_states, attention_mask, q_len
- )
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
- attn_output = self.wo(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
- def _flash_attention_forward(
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
- ):
- """
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
- first unpad the input, then computes the attention scores and pad the final attention scores.
-
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- dropout (`int`, *optional*):
- Attention dropout
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
- """
- # Contains at least one padding token in the sequence
- causal = self.is_causal and query_length != 1
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
- query_states, key_states, value_states, attention_mask, query_length
- )
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- )
-
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
- else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
- )
-
- return attn_output
-
- def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
- key_layer = index_first_axis(
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
- )
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
- )
-
- if query_length == kv_seq_len:
- query_layer = index_first_axis(
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
- )
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q.to(torch.int64),
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
-
-INTERNLM2_ATTENTION_CLASSES = {
- "eager": InternLM2Attention,
- "flash_attention_2": InternLM2FlashAttention2,
-}
-
-# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
-class InternLM2DecoderLayer(nn.Module):
- def __init__(self, config: InternLM2Config):
- super().__init__()
- self.hidden_size = config.hidden_size
-
- self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
-
- self.feed_forward = InternLM2MLP(config)
- self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*):
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
- query_sequence_length, key_sequence_length)` if default attention is used.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- if "padding_mask" in kwargs:
- warnings.warn(
- "Passing `padding_mask` is deprecated and will be removed in v4.37. "
- "Please make sure use `attention_mask` instead.`"
- )
-
- residual = hidden_states
-
- hidden_states = self.attention_norm(hidden_states)
-
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.attention(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- **kwargs,
- )
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.ffn_norm(hidden_states)
- hidden_states = self.feed_forward(hidden_states)
- hidden_states = residual + hidden_states
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (present_key_value,)
-
- return outputs
-
-
-InternLM2_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`InternLM2Config`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
-@add_start_docstrings(
- "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
- InternLM2_START_DOCSTRING,
-)
-class InternLM2PreTrainedModel(PreTrainedModel):
- config_class = InternLM2Config
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["InternLM2DecoderLayer"]
- _skip_keys_device_placement = "past_key_values"
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
-InternLM2_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
- when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-# Modified from transformers.model.llama.modeling_llama.LlamaModel
-@add_start_docstrings(
- "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
- InternLM2_START_DOCSTRING,
-)
-class InternLM2Model(InternLM2PreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
-
- Args:
- config: InternLM2Config
- """
-
- _auto_class = "AutoModel"
-
- def __init__(self, config: InternLM2Config):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.config = config
-
- self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-
- self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
- self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.tok_embeddings = value
-
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape,
- inputs_embeds.dtype,
- device=inputs_embeds.device,
- past_key_values_length=past_key_values_length,
- )
-
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
- inputs_embeds.device
- )
- combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
- )
-
- return combined_attention_mask
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if self.config.attn_implementation == "flash_attention_2":
- _import_flash_attn()
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape[:2]
- elif inputs_embeds is not None:
- batch_size, seq_length = inputs_embeds.shape[:2]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
-
- seq_length_with_past = seq_length
- past_key_values_length = 0
- if past_key_values is not None:
- past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
- )
- position_ids = position_ids.unsqueeze(0)
-
- if inputs_embeds is None:
- inputs_embeds = self.tok_embeddings(input_ids)
-
- if self.config.attn_implementation == "flash_attention_2":
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
- )
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
- )
-
- # embed positions
- hidden_states = inputs_embeds
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
-
- for idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
- if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, None)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
- hidden_states,
- attention_mask,
- position_ids,
- None,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- next_cache = next_decoder_cache if use_cache else None
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
-
-
-# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
-class InternLM2ForCausalLM(InternLM2PreTrainedModel):
- _auto_class = "AutoModelForCausalLM"
-
- _tied_weights_keys = ["output.weight"]
-
- def __init__(self, config):
- super().__init__(config)
- self.model = InternLM2Model(config)
- self.vocab_size = config.vocab_size
- self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- def get_output_embeddings(self):
- return self.output
-
- def set_output_embeddings(self, new_embeddings):
- self.output = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
-
- >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- logits = self.output(hidden_states)
- logits = logits.float()
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
- ):
- if past_key_values is not None:
- past_length = past_key_values[0][0].shape[2]
-
- # Some generation methods already pass only the last input ID
- if input_ids.shape[1] > past_length:
- remove_prefix_length = past_length
- else:
- # Default to old behavior: keep only final ID
- remove_prefix_length = input_ids.shape[1] - 1
-
- input_ids = input_ids[:, remove_prefix_length:]
-
- position_ids = kwargs.get("position_ids", None)
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
-
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and past_key_values is None:
- model_inputs = {"inputs_embeds": inputs_embeds}
- else:
- model_inputs = {"input_ids": input_ids}
-
- model_inputs.update(
- {
- "position_ids": position_ids,
- "past_key_values": past_key_values,
- "use_cache": kwargs.get("use_cache"),
- "attention_mask": attention_mask,
- }
- )
- return model_inputs
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
- )
- return reordered_past
-
- def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
- if tokenizer.add_bos_token:
- prompt = ""
- else:
- prompt = tokenizer.bos_token
- if meta_instruction:
- prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
- for record in history:
- prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
- prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
- return tokenizer([prompt], return_tensors="pt")
-
- @torch.no_grad()
- def chat(
- self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = [],
- streamer: Optional[BaseStreamer] = None,
- max_new_tokens: int = 1024,
- do_sample: bool = True,
- temperature: float = 0.8,
- top_p: float = 0.8,
- meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
- "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
- "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
- **kwargs,
- ):
- inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
- inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
- # also add end-of-assistant token in eos token id to avoid unnecessary generation
- eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
- outputs = self.generate(
- **inputs,
- streamer=streamer,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
- top_p=top_p,
- eos_token_id=eos_token_id,
- **kwargs,
- )
- outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
- response = tokenizer.decode(outputs, skip_special_tokens=True)
- response = response.split("<|im_end|>")[0]
- history = history + [(query, response)]
- return response, history
-
- @torch.no_grad()
- def stream_chat(
- self,
- tokenizer,
- query: str,
- history: List[Tuple[str, str]] = [],
- max_new_tokens: int = 1024,
- do_sample: bool = True,
- temperature: float = 0.8,
- top_p: float = 0.8,
- **kwargs,
- ):
- """
- Return a generator in format: (response, history)
- Eg.
- ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
- ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
- """
- if BaseStreamer is None:
- raise ModuleNotFoundError(
- "The version of `transformers` is too low. Please make sure "
- "that you have installed `transformers>=4.28.0`."
- )
-
- response_queue = queue.Queue(maxsize=20)
-
- class ChatStreamer(BaseStreamer):
- def __init__(self, tokenizer) -> None:
- super().__init__()
- self.tokenizer = tokenizer
- self.queue = response_queue
- self.query = query
- self.history = history
- self.response = ""
- self.cache = []
- self.received_inputs = False
- self.queue.put((self.response, history + [(self.query, self.response)]))
-
- def put(self, value):
- if len(value.shape) > 1 and value.shape[0] > 1:
- raise ValueError("ChatStreamer only supports batch size 1")
- elif len(value.shape) > 1:
- value = value[0]
-
- if not self.received_inputs:
- # The first received value is input_ids, ignore here
- self.received_inputs = True
- return
-
- self.cache.extend(value.tolist())
- token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
- if token.strip() != "<|im_end|>":
- self.response = self.response + token
- history = self.history + [(self.query, self.response)]
- self.queue.put((self.response, history))
- self.cache = []
- else:
- self.end()
-
- def end(self):
- self.queue.put(None)
-
- def stream_producer():
- return self.chat(
- tokenizer=tokenizer,
- query=query,
- streamer=ChatStreamer(tokenizer=tokenizer),
- history=history,
- max_new_tokens=max_new_tokens,
- do_sample=do_sample,
- temperature=temperature,
- top_p=top_p,
- **kwargs,
- )
-
- def consumer():
- producer = threading.Thread(target=stream_producer)
- producer.start()
- while True:
- res = response_queue.get()
- if res is None:
- return
- yield res
-
- return consumer()
-
-# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
-class InternLM2ForRewardModel(InternLM2PreTrainedModel):
-
- _auto_class = "AutoModel"
- _tied_weights_keys = ["v_head.weight"]
-
- def __init__(self, config):
- super().__init__(config)
- self.model = InternLM2Model(config)
- self.vocab_size = config.vocab_size
- self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
- self.reward_token_id = config.reward_token_id
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- def get_output_embeddings(self):
- return self.v_head
-
- def set_output_embeddings(self, new_embeddings):
- self.v_head = new_embeddings
-
- def set_decoder(self, decoder):
- self.model = decoder
-
- def get_decoder(self):
- return self.model
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
-
- >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
- ```"""
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- hidden_states = outputs[0]
- hidden_states = self.v_head(hidden_states)
- # get end reward token's score
- ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1,1)
-
- reward_scores = torch.gather(hidden_states.squeeze(-1), 1, ends)
-
- loss = None
-
- if not return_dict:
- output = (reward_scores,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=reward_scores,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- @torch.no_grad()
- def get_score(
- self,
- tokenizer,
- conversation: List[dict],
- **kwargs,
- ):
- conversation_str = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
- input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
- # add reward score token at the end of the input_ids
- input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1).to(self.device)
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool).to(self.device)
-
- outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
- score = outputs[0].cpu().item()
- return score
-
- @torch.no_grad()
- def get_scores(
- self,
- tokenizer,
- conversations: List[List[dict]],
- **kwargs,
- ):
- conversation_strs = [tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) for conversation in conversations]
- batch_input_ids = []
- attention_masks = []
-
- for conversation_str in conversation_strs:
- input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
- input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1).squeeze(0)
- attention_mask = torch.ones(input_ids.shape, dtype=torch.bool)
- batch_input_ids.append(input_ids)
- attention_masks.append(attention_mask)
-
- r_pad_batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
- r_pad_attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)
-
- outputs = self.forward(input_ids=r_pad_batch_input_ids.to(self.device), attention_mask=r_pad_attention_masks.to(self.device), **kwargs)
- scores = outputs[0].cpu().tolist()
- return scores
-
- @torch.no_grad()
- def compare(
- self,
- tokenizer,
- conversation1: List[dict],
- conversation2: List[dict],
- return_logits: bool = False,
- **kwargs,
- ):
- score1 = self.get_score(tokenizer, conversation1, **kwargs)
- score2 = self.get_score(tokenizer, conversation2, **kwargs)
- if return_logits:
- return score1, score2
- else:
- return score1 > score2
-
- @torch.no_grad()
- def rank(
- self,
- tokenizer,
- conversations: List[List[dict]],
- return_logits: bool = False,
- **kwargs,
- ):
- scores = self.get_scores(tokenizer, conversations, **kwargs)
- if return_logits:
- return scores
- else:
- return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
-
-
-# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
-@add_start_docstrings(
- """
- The InternLM2 Model transformer with a sequence classification head on top (linear layer).
-
- [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
- as other causal models (e.g. GPT-2) do.
-
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- InternLM2_START_DOCSTRING,
-)
-class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = InternLM2Model(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.tok_embeddings
-
- def set_input_embeddings(self, value):
- self.model.tok_embeddings = value
-
- @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
- logits.device
- )
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
-
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
diff --git a/code/xtuner/tools/model_converters/pth_to_hf.py b/code/xtuner/tools/model_converters/pth_to_hf.py
deleted file mode 100644
index 76a7a2b46e3c84d02614d57f6be3f0a2f7ff1a67..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/model_converters/pth_to_hf.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os.path as osp
-import shutil
-import warnings
-
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from mmengine import print_log
-from mmengine.config import Config, DictAction
-from mmengine.fileio import PetrelBackend, get_file_backend
-from mmengine.utils import mkdir_or_exist
-from tqdm import tqdm
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.registry import BUILDER
-
-
-def parse_args():
- parser = argparse.ArgumentParser(
- description='Convert the pth model to HuggingFace model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('pth_model', help='pth model file')
- parser.add_argument(
- 'save_dir', help='the directory to save HuggingFace model')
- parser.add_argument(
- '--fp32',
- action='store_true',
- help='Save LLM in fp32. If not set, fp16 will be used by default.')
- parser.add_argument(
- '--max-shard-size',
- type=str,
- default='2GB',
- help='Only applicable for LLM. The maximum size for '
- 'each sharded checkpoint.')
- parser.add_argument(
- '--safe-serialization',
- action='store_true',
- help='Indicate if using `safe_serialization`')
- parser.add_argument(
- '--save-format',
- default='xtuner',
- choices=('xtuner', 'official', 'huggingface'),
- help='Only applicable for LLaVAModel. Indicate the save format.')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- args = parser.parse_args()
- return args
-
-
-def main():
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- model_name = cfg.model.type if isinstance(cfg.model.type,
- str) else cfg.model.type.__name__
- use_meta_init = False
-
- if 'LLaVAModel' in model_name:
- cfg.model.pretrained_pth = None
- cfg.model.llm.pop('quantization_config', None)
- if args.save_format != 'xtuner':
- use_meta_init = False
- if 'Reward' in model_name:
- use_meta_init = False
- cfg.model.llm.pop('quantization_config', None)
-
- if use_meta_init:
- try:
- # Initializing the model with meta-tensor can reduce unwanted
- # memory usage.
- with init_empty_weights():
- with warnings.catch_warnings():
- warnings.filterwarnings(
- 'ignore', message='.*non-meta.*', category=UserWarning)
- model = BUILDER.build(cfg.model)
- except NotImplementedError as e:
- # Cannot initialize the model with meta tensor if the model is
- # quantized.
- if 'Cannot copy out of meta tensor' in str(e):
- model = BUILDER.build(cfg.model)
- else:
- raise e
- else:
- model = BUILDER.build(cfg.model)
-
- backend = get_file_backend(args.pth_model)
- if isinstance(backend, PetrelBackend):
- from xtuner.utils.fileio import patch_fileio
- with patch_fileio():
- state_dict = guess_load_checkpoint(args.pth_model)
- else:
- state_dict = guess_load_checkpoint(args.pth_model)
-
- for name, param in tqdm(state_dict.items(), desc='Load State Dict'):
- set_module_tensor_to_device(model, name, 'cpu', param)
-
- model.llm.config.use_cache = True
-
- print_log(f'Load PTH model from {args.pth_model}', 'current')
-
- mkdir_or_exist(args.save_dir)
-
- save_pretrained_kwargs = {
- 'max_shard_size': args.max_shard_size,
- 'safe_serialization': args.safe_serialization
- }
- model.to_hf(
- cfg=cfg,
- save_dir=args.save_dir,
- fp32=args.fp32,
- save_pretrained_kwargs=save_pretrained_kwargs,
- save_format=args.save_format)
-
- shutil.copyfile(args.config, osp.join(args.save_dir, 'xtuner_config.py'))
- print_log('All done!', 'current')
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/model_converters/split.py b/code/xtuner/tools/model_converters/split.py
deleted file mode 100644
index da0e4d7b765a135ed8437c68befdb070da4a265a..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/model_converters/split.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import copy
-import json
-import os
-import os.path as osp
-import shutil
-
-import torch
-from mmengine.utils import mkdir_or_exist
-
-
-def parse_args():
- parser = argparse.ArgumentParser(
- description='Split a HuggingFace model to the smallest sharded one')
- parser.add_argument('src_dir', help='the directory of the model')
- parser.add_argument('dst_dir', help='the directory to save the new model')
- args = parser.parse_args()
- return args
-
-
-def main():
- args = parse_args()
- mkdir_or_exist(args.dst_dir)
-
- all_files = os.listdir(args.src_dir)
- for name in all_files:
- if not name.startswith(('pytorch_model', '.')):
- src_path = osp.join(args.src_dir, name)
- dst_path = osp.join(args.dst_dir, name)
- shutil.copy(src_path, dst_path)
-
- with open(osp.join(args.src_dir, 'pytorch_model.bin.index.json')) as f:
- index = json.load(f)
-
- n_shard = len(index['weight_map'])
- new_index = copy.deepcopy(index)
- new_index['weight_map'] = {}
- cnt = 1
-
- checkpoints = set(index['weight_map'].values())
- for ckpt in checkpoints:
- state_dict = torch.load(
- osp.join(args.src_dir, ckpt), map_location='cuda')
- keys = sorted(list(state_dict.keys()))
- for k in keys:
- new_state_dict_name = 'pytorch_model-{:05d}-of-{:05d}.bin'.format(
- cnt, n_shard)
- new_index['weight_map'][k] = new_state_dict_name
- new_state_dict = {k: state_dict[k]}
- torch.save(new_state_dict,
- osp.join(args.dst_dir, new_state_dict_name))
- cnt += 1
- del state_dict
- torch.cuda.empty_cache()
- with open(osp.join(args.dst_dir, 'pytorch_model.bin.index.json'),
- 'w') as f:
- json.dump(new_index, f)
- assert new_index['weight_map'].keys() == index['weight_map'].keys(
- ), 'Mismatch on `weight_map`!'
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/plugins/__init__.py b/code/xtuner/tools/plugins/__init__.py
deleted file mode 100644
index b893bcac8976bed61f0526d57f22a118b6c6b848..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/plugins/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .api import plugins_api
-
-__all__ = ['plugins_api']
diff --git a/code/xtuner/tools/plugins/api.py b/code/xtuner/tools/plugins/api.py
deleted file mode 100644
index 7ac6579d6152564e4c7e5d885e06b39b8a03c65f..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/plugins/api.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import re
-
-
-def plugins_api(input_str,
- calculate_open=True,
- solve_open=True,
- search_open=True):
-
- pattern = r'(Solve|solve|Solver|solver|Calculate|calculate|Calculator|calculator|Search)\("([^"]*)"\)' # noqa: E501
-
- matches = re.findall(pattern, input_str)
-
- converted_str = '<|Results|>:\n'
-
- for i in range(len(matches)):
- if matches[i][0] in [
- 'Calculate', 'calculate'
- 'Calculator', 'calculator'
- ]:
- if calculate_open:
- from .calculate import Calculate
- result = Calculate(matches[i][1])
- else:
- result = None
- converted_str += f"Calculate(\"{matches[i][1]}\") => {result}\n"
- elif matches[i][0] in ['Solve', 'solve', 'Solver', 'solver']:
- if solve_open:
- from .solve import Solve
- result = Solve(matches[i][1])
- else:
- result = None
- converted_str += f"Solve(\"{matches[i][1]}\") =>\n{result}\n"
- elif matches[i][0] == 'Search':
- if search_open:
- from .search import Search
- result = Search(matches[i][1])
- else:
- result = None
- converted_str += f"Search(\"{matches[i][1]}\") =>\n{result}"
-
- converted_str += '\n'
- return converted_str
diff --git a/code/xtuner/tools/plugins/calculate.py b/code/xtuner/tools/plugins/calculate.py
deleted file mode 100644
index 48ed436cbeddd35de34fbb26d1f6f1e7d85fa810..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/plugins/calculate.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from math import * # noqa: F401, F403
-
-
-def Calculate(expression):
- res = ''
- for exp in expression.split(';'):
- try:
- res += '{:.2f};'.format(eval(exp.replace('^', '**')))
- except Exception:
- res += 'No result.'
- if res[-1] == ';':
- res = res[:-1]
- return res
diff --git a/code/xtuner/tools/plugins/search.py b/code/xtuner/tools/plugins/search.py
deleted file mode 100644
index 392bc86204fd43a7312bfd3ed13a30aef9fc4f42..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/plugins/search.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import sys
-
-import requests
-
-try:
- SERPER_API_KEY = os.environ['SERPER_API_KEY']
-except Exception:
- print('Please obtain the `SERPER_API_KEY` from https://serper.dev and '
- 'set it using `export SERPER_API_KEY=xxx`.')
- sys.exit(1)
-
-
-def parse_results(results, k=10):
- snippets = []
-
- for result in results['organic'][:k]:
- if 'snippet' in result:
- snippets.append(result['snippet'])
- for attribute, value in result.get('attributes', {}).items():
- snippets.append(f'{attribute}: {value}.')
- return snippets
-
-
-def search(api_key, search_term, **kwargs):
- headers = {
- 'X-API-KEY': api_key,
- 'Content-Type': 'application/json',
- }
- params = {
- 'q': search_term,
- **{key: value
- for key, value in kwargs.items() if value is not None},
- }
- try:
- response = requests.post(
- 'https://google.serper.dev/search',
- headers=headers,
- params=params,
- timeout=5)
- except Exception as e:
- return -1, str(e)
- return response.status_code, response.json()
-
-
-def Search(q, k=10):
- status_code, response = search(SERPER_API_KEY, q)
- if status_code != 200:
- ret = 'None\n'
- else:
- text = parse_results(response, k=k)
- ret = ''
- for idx, res in enumerate(text):
- ret += f"<|{idx+1}|>: '{res}'\n"
- return ret
diff --git a/code/xtuner/tools/plugins/solve.py b/code/xtuner/tools/plugins/solve.py
deleted file mode 100644
index 20266a23f492cc5e7264d1a46398d64c94267579..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/plugins/solve.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-import re
-from math import * # noqa: F401, F403
-
-from sympy import Eq, solve, symbols
-
-from .calculate import Calculate
-
-
-def Solve(equations_str):
- try:
- equations_str = equations_str.replace(' ', '')
- equations_ori = re.split(r'[,;]+', equations_str)
- equations_str = equations_str.replace('^', '**')
- equations_str = re.sub(r'(\(.*\))([a-zA-Z])', r'\1 * \2',
- equations_str)
- equations_str = re.sub(r'(\d+)([a-zA-Z])', r'\1 * \2', equations_str)
- equations_str = equations_str.replace('pi', str(math.pi))
- equations = re.split(r'[,;]+', equations_str)
- vars_list = list(set(re.findall(r'[a-zA-Z]+', equations_str)))
- vars = {var: symbols(var) for var in vars_list}
-
- output = ''
- eqs = []
- for eq in equations:
- if '=' in eq:
- left, right = eq.split('=')
- eqs.append(
- Eq(
- eval(left.strip(), {}, vars),
- eval(right.strip(), {}, vars)))
- solutions = solve(eqs, vars, dict=True)
-
- vars_values = {var: [] for var in vars_list}
- if isinstance(solutions, list):
- for idx, solution in enumerate(solutions):
- for var, sol in solution.items():
- output += f'{var}_{idx} = {sol}\n'
- vars_values[str(var)].append(sol)
- else:
- for var, sol in solutions.items():
- output += f'{var} = {sol}\n'
- vars_values[str(var)].append(sol)
- for eq, eq_o in zip(equations, equations_ori):
- if '=' not in eq:
- for var in vars_list:
- need_note = True if len(vars_values[var]) > 1 else False
- for idx, value in enumerate(vars_values[var]):
- eq_to_calc = eq.replace(var, str(value))
- calc_result = Calculate(eq_to_calc)
- if need_note:
- eq_name = eq_o.replace(var, f'{var}_{idx}')
- else:
- eq_name = eq_o
- if calc_result != 'No results.':
- output += f'{eq_name} = {calc_result}\n'
-
- return output.strip()
- except Exception:
- return 'No result.'
diff --git a/code/xtuner/tools/process_untokenized_datasets.py b/code/xtuner/tools/process_untokenized_datasets.py
deleted file mode 100644
index c41905ee6daaebca1f9e546b5588c6d627baea39..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/process_untokenized_datasets.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import warnings
-
-from mmengine import Config, ConfigDict
-from mmengine.config.lazy import LazyObject
-
-from xtuner.registry import BUILDER
-
-# ignore FutureWarning in hf datasets
-warnings.simplefilter(action='ignore', category=FutureWarning)
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--save-folder', help='The folder to save data order.')
- args = parser.parse_args()
- return args
-
-
-def modify_config(config, dataset_save_folder):
- dataset = ConfigDict(
- type=LazyObject('datasets', 'load_from_disk'),
- dataset_path=dataset_save_folder)
- train_dataset = ConfigDict(
- type=LazyObject('xtuner.dataset', 'process_hf_dataset'),
- dataset=dataset,
- do_dataset_tokenization=False,
- tokenizer=None,
- max_length=None,
- dataset_map_fn=None,
- template_map_fn=None,
- max_dataset_length=None,
- split=None,
- remove_unused_columns=False,
- rename_maps=[],
- pack_to_max_length=False,
- input_ids_with_output=False)
- config.train_dataloader.dataset = train_dataset
- return config
-
-
-def process_untokenized_dataset(config):
- dataset = BUILDER.build(config.train_dataloader.dataset)
- return dataset
-
-
-if __name__ == '__main__':
- args = parse_args()
- cfg = Config.fromfile(args.config)
-
- print('Start to process untokenized dataset...')
- processed_dataset = process_untokenized_dataset(cfg)
- print('Processing untokenized dataset finished.')
-
- processed_dataset_save_folder = args.save_folder
- if not os.path.isabs(processed_dataset_save_folder):
- processed_dataset_save_folder = os.path.join(
- os.getcwd(), processed_dataset_save_folder)
- modified_cfg = modify_config(cfg, processed_dataset_save_folder)
-
- print('Start to save processed dataset...')
- processed_dataset.save_to_disk(processed_dataset_save_folder)
- print(
- f'Processed dataset has been saved to {processed_dataset_save_folder}')
-
- cfg_folder, cfg_file_name = os.path.split(args.config)
- cfg_file_name = cfg_file_name.split('.')[0]
- cfg_file_name = f'{cfg_file_name}_modified.py'
- modified_cfg_save_path = os.path.join(cfg_folder, cfg_file_name)
- modified_cfg.dump(modified_cfg_save_path)
- print(f'Modified config has been saved to {modified_cfg_save_path}. '
- 'Please use this new config for the next training phase.')
diff --git a/code/xtuner/tools/process_untokenized_datasets_legacy.py b/code/xtuner/tools/process_untokenized_datasets_legacy.py
deleted file mode 100644
index 8b4dd5a7de93e2966b2bb3d9c579a2e4669db034..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/process_untokenized_datasets_legacy.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import ast
-import multiprocessing
-import os
-import warnings
-from functools import partial
-
-from datasets import Dataset, DatasetDict, load_dataset
-from mmengine import ConfigDict
-from transformers import AutoTokenizer
-
-from xtuner.dataset.huggingface import process
-from xtuner.dataset.map_fns import (DATASET_FORMAT_MAPPING,
- template_map_fn_factory)
-from xtuner.utils import PROMPT_TEMPLATE
-
-# ignore FutureWarning in hf datasets
-warnings.simplefilter(action='ignore', category=FutureWarning)
-"""
-ftdp dataset:
-srun -p llm_razor --quotatype=auto --gres=gpu:1 --ntasks=1 \
- --ntasks-per-node=1 --cpus-per-task=5 --kill-on-bad-exit=1 \
- python xtuner/tools/process_untokenized_datasets.py \
- --data-folder /path/to/data/folder \
- --save-folder ./processed \
- --tokenizer-path pretrained_model_name_or_path \
- --prompt-template internlm2_chat \
- --dataset-format ftdp
-
-normal json dataset:
-srun -p llm_razor --quotatype=auto --gres=gpu:1 --ntasks=1 \
- --ntasks-per-node=1 --cpus-per-task=5 --kill-on-bad-exit=1 \
- python xtuner/tools/process_untokenized_datasets.py \
- --data-folder /path/to/data/folder \
- --save-folder ./processed \
- --tokenizer-path pretrained_model_name_or_path \
- --prompt-template internlm2_chat
-"""
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('--data-folder', help='Data folder')
- parser.add_argument('--save-folder', help='The folder to save data order.')
- parser.add_argument(
- '--tokenizer-path', help='The path to the hf tokenizer.')
- parser.add_argument(
- '--dataset-format',
- choices=list(DATASET_FORMAT_MAPPING.keys()) + ['ftdp'],
- default=None,
- help='Which dataset format is this data. The available choices are '
- f"{list(DATASET_FORMAT_MAPPING.keys()) + ['ftdp']}. ")
- parser.add_argument(
- '--prompt-template',
- choices=PROMPT_TEMPLATE.keys(),
- help='Which prompt template need to be added to the dataset. '
- f'The available choices are {PROMPT_TEMPLATE.keys()}')
- parser.add_argument(
- '--max-length', default=32768, help='Max sequence length.')
- parser.add_argument(
- '--pack-to-max-length',
- action='store_true',
- help='Whether to pack the dataset to the `max_length `.')
- parser.add_argument(
- '--file-type',
- default='.json',
- help='We want to get the order of the file in this type.')
- parser.add_argument(
- '--data-order-path',
- default=None,
- help=('The path to a txt file which contains the a list of data path.'
- ' It can be obtain by xtuner/tools/get_data_order.py script.'))
- args = parser.parse_args()
- return args
-
-
-def process_one(fp,
- tokenizer,
- max_length,
- pack_to_max_length,
- dataset_map_fn=None,
- template_map_fn=None,
- is_ftdp=False):
- dataset = []
- if is_ftdp:
- with open(fp) as file:
- lines = file.readlines()
- for line in lines:
- line = ast.literal_eval(line)
- dataset.append({'messages': line})
- dataset = Dataset.from_list(dataset)
- else:
- # load formal json data
- dataset = load_dataset('json', data_files=fp)
- dataset = dataset['train']
- dataset = process(
- dataset,
- tokenizer=tokenizer,
- max_length=max_length,
- dataset_map_fn=dataset_map_fn,
- template_map_fn=template_map_fn,
- remove_unused_columns=True,
- pack_to_max_length=pack_to_max_length,
- map_num_proc=32)
- return fp, dataset
-
-
-def process_untokenized_dataset(folder,
- tokenizer,
- max_length,
- pack_to_max_length,
- dataset_map_fn,
- prompt_template,
- data_order_path=None,
- file_type='.json',
- is_ftdp=False):
- assert os.path.exists(folder), f'{folder} does not exist.'
- datasets_dict = {}
-
- if data_order_path is not None:
- data_order = load_dataset(
- 'text', data_files=data_order_path, split='train')['text']
- for i, fp in enumerate(data_order):
- data_order[i] = os.path.join(folder, fp)
- else:
- triples = list(os.walk(folder, followlinks=True))
- data_order = []
- for root, dirs, files in triples:
- dirs.sort()
- for fn in sorted(files):
- if fn.endswith(file_type):
- fp = os.path.join(root, fn)
- data_order.append(fp)
- print('All file path: ', data_order)
-
- pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
- template_map_fn = ConfigDict(
- type=template_map_fn_factory, template=prompt_template)
- process_single = partial(
- process_one,
- tokenizer=tokenizer,
- max_length=max_length,
- pack_to_max_length=pack_to_max_length,
- dataset_map_fn=dataset_map_fn,
- template_map_fn=template_map_fn,
- is_ftdp=is_ftdp)
- out = pool.map(process_single, data_order)
- pool.close()
- pool.join()
- for idx, (key, dataset) in enumerate(out):
- assert data_order[idx] == key
- dataset = dataset.remove_columns('length')
- datasets_dict[str(idx)] = dataset
- datasets_dict = DatasetDict(datasets_dict)
- return datasets_dict
-
-
-if __name__ == '__main__':
- args = parse_args()
- tokenizer = ConfigDict(
- type=AutoTokenizer.from_pretrained,
- pretrained_model_name_or_path=args.tokenizer_path,
- trust_remote_code=True,
- padding_side='right')
-
- if args.dataset_format is None:
- dataset_map_fn = None
- elif args.dataset_format == 'ftdp':
- dataset_map_fn = DATASET_FORMAT_MAPPING['openai']
- else:
- dataset_map_fn = DATASET_FORMAT_MAPPING[args.dataset_format]
-
- datasets_dict = process_untokenized_dataset(
- args.data_folder,
- tokenizer,
- args.max_length,
- args.pack_to_max_length,
- dataset_map_fn,
- PROMPT_TEMPLATE[args.prompt_template],
- data_order_path=args.data_order_path,
- file_type=args.file_type,
- is_ftdp=args.dataset_format == 'ftdp')
- datasets_dict.save_to_disk(args.save_folder)
diff --git a/code/xtuner/tools/process_untokenized_llava_data.py b/code/xtuner/tools/process_untokenized_llava_data.py
deleted file mode 100644
index 4d0c075855734835d3a72a2c98ee7be38b85bfac..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/process_untokenized_llava_data.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import warnings
-
-from mmengine import Config
-
-from xtuner.registry import BUILDER
-
-# ignore FutureWarning in hf datasets
-warnings.simplefilter(action='ignore', category=FutureWarning)
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--save-folder', help='The folder to save data order.')
- args = parser.parse_args()
- return args
-
-
-def build_llava_dataset(config):
- dataset = BUILDER.build(config.train_dataloader.dataset)
- return dataset
-
-
-if __name__ == '__main__':
- args = parse_args()
- cfg = Config.fromfile(args.config)
-
- llava_dataset = build_llava_dataset(cfg)
- text_data = llava_dataset.text_data
-
- text_data.save_to_disk(args.save_folder)
diff --git a/code/xtuner/tools/test.py b/code/xtuner/tools/test.py
deleted file mode 100644
index 3dbbcb857a639a37e88cc8e5947e8f81d09c0dad..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test.py
+++ /dev/null
@@ -1,576 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-from types import FunctionType
-import deepspeed
-import time # <<< ADDED: Import the time module
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-from sympy import im
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.registry import MAP_FUNC
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-import torch
-from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-from xtuner.utils import PROMPT_TEMPLATE
-from PIL import Image
-import pandas as pd
-import numpy as np
-from transformers import GenerationConfig, StoppingCriteriaList
-
-import os
-from xtuner.model.llava_dim_reducer import TextGuidedVisualTokenAttentionReducer
-
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', default=None, help='checkpoint file')
- parser.add_argument('--test_slide_csv', default=None, help='test_slide_csv')
- parser.add_argument('--test_output_csv', default=None, help='test_output_csv')
- parser.add_argument('--tumor_type', default=None, help='test_output_csv')
- parser.add_argument(
- '--eval_output_path',
- default='slidechat_baseline_eval.txt',
- help='path to save evaluation results')
-
- parser.add_argument(
- '--torch-dtype',
- default='bf16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--work-dir',
- help='the directory to save the file containing evaluation metrics')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- parser.add_argument('--divprune_ratio', type = float, default = None, help='the ratio for divprune')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
- args = parser.parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in dict.items(cfg_dict):
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-def pairwise_l1_distance(matrix: torch.Tensor) -> torch.Tensor:
- """
- Compute the full pairwise L1 (Manhattan) distance matrix
- for an [N, D] tensor.
- """
- # torch.cdist with p=1 computes L1 distance
- return torch.cdist(matrix, matrix, p=1)
-
-def pairwise_cosine_similarity(matrix):
- norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
- cosine_similarity = torch.mm(norm_matrix, norm_matrix.t())
- return cosine_similarity
-
-
-def DivPrune(visual_feature_vectors, image_feature_length,
- cosine_matrix=None, threshold_ratio=0.1):
- threshold_terms = int(round(threshold_ratio * image_feature_length))
- if cosine_matrix is None:
- cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors))
-
- s = torch.empty(threshold_terms, dtype=torch.long, device=visual_feature_vectors.device)
- for i in range(threshold_terms):
- if i == 0:
- m2 = cosine_matrix
- else:
- m2 = torch.index_select(cosine_matrix, 0, torch.index_select(s, 0, torch.arange(0, i, device=cosine_matrix.device)))
-
- if i == 0:
- scores = torch.topk(m2, 2, dim=0, largest=False).values[1, :]
- else:
- scores = torch.min(m2, dim=0).values
-
- phrase_to_add_idx = torch.argmax(scores)
- s[i] = phrase_to_add_idx
- return s, cosine_matrix
-
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- # register FunctionType object in cfg to `MAP_FUNC` Registry and
- # change these FunctionType object to str
- register_function(cfg._cfg_dict)
-
- # work_dir is determined in this priority: CLI > segment in file > filename
- if args.work_dir is not None:
- # update configs according to CLI args if args.work_dir is not None
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- # use config filename as default work_dir if cfg.work_dir is None
- cfg.work_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.config))[0])
-
- # build the runner from config
- if 'runner_type' not in cfg:
- # build the default runner
- runner = Runner.from_cfg(cfg)
- else:
- # build customized runner from the registry
- # if 'runner_type' is set in the cfg
- runner = RUNNERS.build(cfg)
-
- model_kwargs = {
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- }
-
- state_dict = guess_load_checkpoint(args.checkpoint)
- # state_dict = torch.load(args.checkpoint, map_location='cpu')
- print(f'available keys in checkpoint: {state_dict.keys()}')
- runner.model.load_state_dict(state_dict, strict=False)
-
-
- ##############################qingq check loaded weights######################################
- missing_keys, unexpected_keys = runner.model.load_state_dict(state_dict, strict=False).missing_keys, \
- runner.model.load_state_dict(state_dict, strict=False).unexpected_keys
-
- print("✅ Missing keys (not in checkpoint):")
- for key in missing_keys:
- print(f" - {key}")
-
- print("\n⚠️ Unexpected keys (in checkpoint but not in model):")
- for key in unexpected_keys:
- print(f" - {key}")
- ##############################qingq check loaded weights######################################
-
-
-
- runner.model.eval()
- runner.logger.info(f'Load checkpoint from {args.checkpoint}')
-
-
- llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
- tokenizer = AutoTokenizer.from_pretrained(
- llm_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
-
- llm = runner.model.llm # dtype: float16
- llm.eval()
-
- LongNet_encoder = runner.model.LongNet_encoder.to(model_kwargs['torch_dtype']) # torch.bfloat16
- LongNet_encoder.cuda()
- LongNet_encoder.eval()
-
- projector = runner.model.projector.to(model_kwargs['torch_dtype'])
- projector.cuda()
- projector.eval()
-
- # Check for and apply the visual token reducer
- visual_token_reducer = None
- if hasattr(runner.model, 'visual_token_reducer'):
- print("Visual token reducer found, applying it.")
- visual_token_reducer = runner.model.visual_token_reducer.to(model_kwargs['torch_dtype'])
- visual_token_reducer.cuda()
- visual_token_reducer.eval()
- # projector = torch.nn.Sequential(projector, visual_token_reducer)
- # print('Using visual token reducer')
-
- mil = None
- if hasattr(runner.model, 'acmil'):
- print("ACMIL found, applying it.")
- mil = runner.model.acmil.to(model_kwargs['torch_dtype'])
- mil.cuda()
- mil.eval()
-
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- df_test_case['Output'] = df_test_case.apply(lambda x: '', axis=1)
- # <<< MODIFIED: Add 'GenerationTime' column for timing results
- columns = ['ID','Slide','Tumor','Broad Category','Narrow Category','Question','A','B','C','D','Answer','Output', 'GenerationTime']
- df_test_output = pd.DataFrame(columns=columns)
- generation_times = [] # <<< ADDED: List to store generation times
-
- if args.test_output_csv:
- output_dir = os.path.dirname(args.test_output_csv)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- for i in range(df_test_case.shape[0]):
- # if i != 541:
- # continue
-
- print('*'*30)
- print('id: ', i, df_test_case.loc[i, 'Slide'])
- # only check the brca
- print('tumor type: ', args.tumor_type)
- if df_test_case.loc[i, 'Tumor'] != args.tumor_type: #'LUAD':
- continue
- print('tumor name: ', df_test_case.loc[i, 'Tumor'])
- tumor_name = df_test_case.loc[i, 'Tumor']
- case_name = df_test_case.loc[i, 'Slide']
-
- # test_image_file = "TCGA_patch_feat/" + df_test_case.loc[i, 'Tumor'] + "/" + case_name + ".csv"
- test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/' + str( tumor_name.lower() ) + '_224x224_b20_t15/pt_files/' + case_name + '.pt'
-
- # test_image_file = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
- # test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/pt_files/TCGA-A7-A0CJ-01Z-00-DX2.pt'
-
-
- # test_image_file = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A6VV-01Z-00-DX2.csv'
- # test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/pt_files/TCGA-A7-A6VV-01Z-00-DX2.pt'
-
- # for some missing files, skip
- if not os.path.exists(test_image_file):
- with open("/data/qingq/PathVLM/baselines/github/SlideChat/outputs/missing_WSI_log.txt", "w") as f: # use "a" to append instead of overwrite
- f.write(test_image_file + "\n")
- f.close()
- continue
-
- if test_image_file.endswith('.csv'):
- image = pd.read_csv(test_image_file) # shape: [num_patches, 513]
- image = image.iloc[:, :512]
- total_rows = image.shape[0]
- sample_num = 38400
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_df = image.iloc[indices]
- image = sampled_df.iloc[:sample_num]
- image = image.to_numpy().reshape(1, image.shape[0], 512) # (1, N, 512)
- image = torch.from_numpy(image)
-
- # qingq modify, our feature format is .pt file
- elif test_image_file.endswith('.pt'):
- image = torch.load(test_image_file, map_location='cpu') # (N, 512)
- image = image.numpy()
- total_rows = image.shape[0]
- print('before sampling image shape', image.shape)
- sample_num = 2 # 38400. original 38400 is out of memory for 45G
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_image = image[indices]
- image = sampled_image[:sample_num]
-
- # image = torch.load(test_image_file, map_location='cpu') # (N, 512)
- # image = image.numpy()
- # total_rows = image.shape[0]
- # print('before sampling image shape', image.shape)
-
- # sample_num = 2000 # 38400 was OOM for 45G
-
- # if total_rows >= sample_num:
- # indices = np.random.choice(total_rows, sample_num, replace=False)
- # sampled_image = image[indices]
- # image = sampled_image # already limited to sample_num
-
- # Reshape and convert to tensor: (1, N, 512)
- image = torch.from_numpy(image.reshape(1, -1, 512)) # final shape: (1, N, 512)
- print('final image shape', image.shape)
-
- else:
- image = Image.open(test_image_file).convert('RGB')
-
- image = image.cuda() # shape (1, patch_num, 512)
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
- question = df_test_case.loc[i, 'Question']
- try:
- options = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- options.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(options)
-
- sample_input = f"{question}\n{options_str}"
- print('Input: ', sample_input)
- except KeyError as e:
- sample_input = question
- print('Input: ', sample_input)
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(
- chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids).cuda()
-
-
- # image.dtype: float32
- # runner.model.dtype: float16
- # image = image.to(runner.model.projector.dtype)
- # image = runner.model.LongNet_encoder(src_tokens=None, token_embeddings=image.permute(1, 0, 2).to(runner.model.llm.dtype))["encoder_out"]
-
-
- # model_kwargs['torch_dtype']
- image = image.to(projector.dtype) # shape (1, patch_num, 512)
- if mil is not None:
- # pixel_values shape: [1, patch_num, 512]
- _, image, _ = mil(image)
-
- pixel_values = image.unsqueeze(0) # shape: [1, patch_num, 512]
- pixel_values = projector(pixel_values) # shape: [1, patch_num,
- else:
- image = LongNet_encoder(src_tokens=None, token_embeddings=image.permute(1, 0, 2))["encoder_out"] # output shape (patch_num, 1, 512)
-
- image = image.permute(1, 0, 2) # shape: [1, patch_num, 512]
-
- # pixel_values = runner.model.projector(image)
- pixel_values = projector(image) # shape: [1, patch_num, 3584]
-
- if args.divprune_ratio is not None and args.divprune_ratio > 0 and args.divprune_ratio < 1.0:
-
- print('Applying divprune with ratio:', args.divprune_ratio)
- # Apply divprune
- pruned_batch_features = []
- for visual_tokens in pixel_values: # Iterate over the batch dimension
- img_feature_len = visual_tokens.shape[0]
- selected_indices, _ = DivPrune(
- visual_tokens,
- img_feature_len,
- threshold_ratio=args.divprune_ratio
- )
- selected_indices = torch.sort(selected_indices).values
- pruned_features = visual_tokens[selected_indices]
- pruned_batch_features.append(pruned_features)
-
- # Stack the list of pruned tensors back into a single batch tensor
- pixel_values = torch.stack(pruned_batch_features, dim=0)
- print('After divprune, pixel_values shape: ', pixel_values.shape)
-
- if visual_token_reducer is not None:
- is_text_guided_reducer = isinstance(
- visual_token_reducer, TextGuidedVisualTokenAttentionReducer
- )
- print("Applying visual token reducer")
-
-
- if is_text_guided_reducer:
- # Get text embeddings and attention mask for the guided reducer
- # input_ids = data['input_ids']
- input_ids = input_ids.unsqueeze(0) # Add batch dimension
- text_attention_mask = None
-
- text_embeddings = llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
-
-
- pixel_values = visual_token_reducer(
- pixel_values, text_embeddings, text_attention_mask
- )
- else:
- # Input to reducer is now (B, T, D)
- pixel_values = visual_token_reducer(pixel_values)
-
-
- print('After visual token reducer, pixel_values shape: ', pixel_values.shape)
- print('After visual token reducer, pixel_values dtype: ', pixel_values.dtype)
-
-
-
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=runner.model.llm,
- input_ids=input_ids.unsqueeze(0),
- # pixel_values= torch.zeros_like(pixel_values).to(model_kwargs['torch_dtype']), # <<< MODIFIED: Use zeros_like to match dtype
- pixel_values=pixel_values, # <<< MODIFIED: Use pixel_values directly
- )
-
- for key in mm_inputs.keys():
- print(key, mm_inputs[key])
-
- max_new_tokens=1000
- gen_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words=[]
- stop_words += prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- stop_criteria.append(
- StopWordStoppingCriteria(tokenizer, word))
-
- # <<< ADDED: Timing logic for the generation step
- torch.cuda.synchronize() # Synchronize to get accurate timing on GPU
- start_time = time.time()
-
- generate_output = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
-
- torch.cuda.synchronize() # Wait for the generation to finish
- end_time = time.time()
- duration = end_time - start_time
- generation_times.append(duration)
-
- generation_output = tokenizer.decode(generate_output[0])
- if generation_output.endswith('<|im_end|>'):
- generation_output = generation_output[:-10]
-
- print('Output: ', generation_output)
- print(f'Generation Time: {duration:.4f} seconds') # <<< ADDED: Print time for this sample
- try:
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'A': df_test_case.loc[i, 'A'],
- 'B': df_test_case.loc[i, 'B'],
- 'C': df_test_case.loc[i, 'C'],
- 'D': df_test_case.loc[i, 'D'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output,
- 'GenerationTime': duration # <<< ADDED: Save duration to the row
- }
- except:
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output,
- 'GenerationTime': duration # <<< ADDED: Save duration to the row
- }
- df_test_output.loc[i] = add_row
- if args.test_output_csv:
- df_test_output.to_csv(args.test_output_csv, index=False) # <<< MODIFIED: Added index=False for cleaner CSV
-
- print('Test ok!')
-
- # <<< ADDED: Calculate and display average generation time
- if generation_times:
- average_time = np.mean(generation_times)
- print(f"\nAverage Generation Time over {len(generation_times)} samples: {average_time:.4f} seconds")
-
-
- # check performance
- def slidechat_performance(output_file_path, eval_output_path):
-
- # Load the CSV
- df = pd.read_csv(output_file_path)
-
-
- # Clean ground-truth answers
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
-
- # Extract the letter before the period in 'Output' (e.g., 'A. Luminal A' → 'A')
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
-
- # Compute exact match
- df['correct'] = df['Answer_clean'] == df['Output_clean']
-
- # Calculate accuracy
- accuracy = df['correct'].mean()
-
- # <<< MODIFIED: Calculate average time from the saved CSV file
- average_gen_time = df['GenerationTime'].mean()
-
- # Print summary
- total = len(df)
- correct = df['correct'].sum()
- print(f"Exact Match Accuracy: {accuracy:.2%} ({correct}/{total})")
- print(f"Average Generation Time: {average_gen_time:.4f} seconds")
-
-
- # Build the result string
- # <<< MODIFIED: Add average time to the output text file
- result_text = f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Correct : {correct}
- Accuracy : {accuracy:.2%}
- Average Generation Time : {average_gen_time:.4f} seconds
- """
-
- # Print to console
- print(output_file_path)
- print(result_text)
-
- # Save to txt file
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path)
- f.write('\n')
- f.write(result_text)
- f.write('\n')
-
- if args.test_output_csv:
- try:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
- except Exception as e:
- print(f"Error during performance evaluation: {e}")
- pass
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_dynamic_llava.py b/code/xtuner/tools/test_dynamic_llava.py
deleted file mode 100644
index e1a0686a691bfd292412b9a27cc5e23840144278..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_dynamic_llava.py
+++ /dev/null
@@ -1,413 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Test DynamicLLaVAQwen25 on SlideBench / TCGA features.
-import argparse
-import os
-import os.path as osp
-import time
-
-import numpy as np
-import pandas as pd
-import torch
-from PIL import Image
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria, PROMPT_TEMPLATE)
-from transformers import AutoTokenizer, GenerationConfig, StoppingCriteriaList
-
-# Optional (only used if you pass a ZeRO checkpoint directory)
-try:
- from deepspeed.checkpoint.utils import load_state_dict_from_zero_checkpoint
- _HAS_DS_ZERO = True
-except Exception:
- _HAS_DS_ZERO = False
-
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto'
-)
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test DynamicLLaVAQwen25 model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', required=True,
- help='Either a compact .pth OR a ZeRO shard directory')
- parser.add_argument('--test_slide_csv', required=True,
- help='Benchmark CSV with questions & meta')
- parser.add_argument('--test_output_csv', required=True,
- help='Where to write per-row outputs')
- parser.add_argument('--tumor_type', default=None,
- help='Filter rows by Tumor column (e.g., BRCA, SKCM)')
- parser.add_argument('--eval_output_path', default='slidechat_eval.txt',
- help='Path to append evaluation results')
- parser.add_argument('--torch-dtype', default='bf16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Aux dtype placement (usually bf16)')
- parser.add_argument('--divprune_ratio', type=float, default=None,
- help='Override model.divprune_ratio at test time (0 str:
- if osp.isfile(cfg_arg):
- return cfg_arg
- try:
- return cfgs_name_path[cfg_arg]
- except KeyError:
- raise FileNotFoundError(f'Cannot find config "{cfg_arg}"')
-
-
-def _load_model_weights(runner: Runner, ckpt_path: str):
- # """Load either a compact .pth or a ZeRO-3 directory into runner.model."""
- # if osp.isdir(ckpt_path):
- # if not _HAS_DS_ZERO:
- # raise RuntimeError(
- # "You provided a directory (ZeRO shards), but deepspeed is not available. "
- # "Install DeepSpeed or pass a compact .pth."
- # )
- # print(f"🔧 Loading ZeRO-3 shards from: {ckpt_path}")
- # load_state_dict_from_zero_checkpoint(runner.model, ckpt_path)
- # class _Compat: # for uniform logging
- # missing_keys, unexpected_keys = [], []
- # return _Compat()
- # else:
- # print(f"🔧 Loading compact checkpoint: {ckpt_path}")
- # state_dict = guess_load_checkpoint(ckpt_path)
- # return runner.model.load_state_dict(state_dict, strict=False)
- state_dict = guess_load_checkpoint(ckpt_path)
- print(f'available keys in checkpoint: {state_dict.keys()}')
- return runner.model.load_state_dict(state_dict, strict=False)
-
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # Resolve config path & load
- cfg_path = _maybe_resolve_config_path(args.config)
- cfg = Config.fromfile(cfg_path)
- cfg.launcher = args.launcher
-
- cfg_options = getattr(args, 'cfg_options', None)
- if cfg_options:
- cfg.merge_from_dict(cfg_options)
-
- work_dir = getattr(args, 'work_dir', None)
- if work_dir is not None:
- cfg.work_dir = work_dir
- elif cfg.get('work_dir', None) is None:
- cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(cfg_path))[0])
-
- # Build runner/model
- dtype = TORCH_DTYPE_MAP[args.torch_dtype]
- runner = Runner.from_cfg(cfg) if 'runner_type' not in cfg else RUNNERS.build(cfg)
-
- runner.model.llm = runner.model.llm.to(dtype=dtype)
- runner.model.LongNet_encoder = runner.model.LongNet_encoder.to(dtype=dtype)
- runner.model.projector = runner.model.projector.to(dtype=dtype)
- # Load weights
- incompat = _load_model_weights(runner, args.checkpoint)
- # --- NEW: Enhanced Checkpoint Loading Report ---
- print("\n--- Checkpoint Loading Report ---")
-
- missing_keys = getattr(incompat, 'missing_keys', [])
- unexpected_keys = getattr(incompat, 'unexpected_keys', [])
-
- if not missing_keys and not unexpected_keys:
- print("✅ Checkpoint loaded successfully with no mismatches.")
- else:
- if missing_keys:
- print("\n🔎 Missing keys (expected in model but NOT found in checkpoint):")
- for k in missing_keys:
- print(f" - {k}")
- if unexpected_keys:
- print("\n🔎 Unexpected keys (found in checkpoint but NOT in model):")
- for k in unexpected_keys:
- print(f" - {k}")
-
- # Check for specific module types
- has_lora_in_checkpoint = any('lora' in k for k in runner.model.state_dict())
- missing_lora = any('lora' in k for k in missing_keys) if has_lora_in_checkpoint else False
-
- has_predictor_in_checkpoint = any('predictor' in k for k in runner.model.state_dict())
- missing_predictor = any('predictor' in k for k in missing_keys) if has_predictor_in_checkpoint else False
-
- if has_lora_in_checkpoint:
- if missing_lora:
- print("\n⚠️ WARNING: LoRA weights seem to be missing from the checkpoint.")
- else:
- print("\n✅ LoRA weights appear to be present and loaded from the checkpoint.")
-
- if has_predictor_in_checkpoint:
- if missing_predictor:
- print("\n⚠️ WARNING: Predictor weights seem to be missing from the checkpoint.")
- else:
- print("\n✅ Predictor weights appear to be present and loaded from the checkpoint.")
-
- print("---------------------------------\n")
- # --- End of Report ---
-
- print('llm model', runner.model.llm)
-
- # Optional runtime override for DivPrune ratio
- if args.divprune_ratio is not None:
- if not (0.0 < args.divprune_ratio < 1.0):
- raise ValueError("--divprune_ratio must be in (0,1).")
- if hasattr(runner.model, "divprune_ratio"):
- print(f"Overriding model.divprune_ratio to {args.divprune_ratio}")
- runner.model.divprune_ratio = args.divprune_ratio
-
- runner.model.eval()
- runner.logger.info(f'Weights ready from {args.checkpoint}')
-
- # Tokenizer (prefer model.tokenizer if present)
- if hasattr(runner.model, "tokenizer"):
- tokenizer = runner.model.tokenizer
- else:
- # Fallback: Qwen2.5 tokenizer
- tokenizer = AutoTokenizer.from_pretrained(
- 'Qwen/Qwen2.5-7B-Instruct', trust_remote_code=True, encode_special_tokens=True
- )
-
- # Read benchmark CSV
- if not args.test_slide_csv or not osp.exists(args.test_slide_csv):
- raise FileNotFoundError("--test_slide_csv is required and must exist.")
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- # Prepare output
- out_cols = ['ID', 'Slide', 'Tumor', 'Broad Category', 'Narrow Category',
- 'Question', 'A', 'B', 'C', 'D', 'Answer', 'Output', 'GenerationTime']
- df_out = pd.DataFrame(columns=out_cols)
- if args.test_output_csv:
- os.makedirs(osp.dirname(args.test_output_csv), exist_ok=True)
-
- # Prompt template
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
-
- gen_times = []
-
- for i in range(df_test_case.shape[0]):
- print('*' * 30)
- slide = df_test_case.loc[i, 'Slide']
- tumor = df_test_case.loc[i, 'Tumor']
- print('row id:', i, 'slide:', slide)
-
- if args.tumor_type is not None and tumor != args.tumor_type:
- continue
-
- # Path to precomputed features
- feat_path = (
- '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/' +
- str(tumor.lower()) + '_224x224_b20_t15/pt_files/' + slide + '.pt'
- )
- if not osp.exists(feat_path):
- missing_log = "/data/qingq/PathVLM/baselines/github/SlideChat/outputs/missing_WSI_log.txt"
- with open(missing_log, "a") as f:
- f.write(feat_path + "\n")
- print(f"⚠️ Missing features, logged: {feat_path}")
- continue
-
- # Load features
- if feat_path.endswith('.csv'):
- df = pd.read_csv(feat_path)
- df = df.iloc[:, :512]
- total = df.shape[0]
- sample_num = min(total, 10240)
- if total >= sample_num:
- idx = np.linspace(0, total - 1, sample_num, dtype=int)
- df = df.iloc[idx]
- image = torch.from_numpy(df.to_numpy().reshape(1, -1, 512))
- elif feat_path.endswith('.pt'):
- arr = torch.load(feat_path, map_location='cpu')
- if isinstance(arr, torch.Tensor):
- arr = arr.cpu().numpy()
- total = arr.shape[0]
- print('before sampling image shape', arr.shape)
- sample_num = min(total, 10240)
- if total >= sample_num:
- idx = np.linspace(0, total - 1, sample_num, dtype=int)
- arr = arr[idx]
- image = torch.from_numpy(arr.reshape(1, -1, 512))
- print('final image shape', image.shape, 'dtype', image.dtype)
- else:
- # Not expected in your pipeline
- _ = Image.open(feat_path).convert('RGB')
- raise ValueError("Expecting precomputed features (.pt/.csv), not raw images.")
-
- image = image.cuda(non_blocking=True)
-
- # Build prompt with one IMAGE token
- question = df_test_case.loc[i, 'Question']
- try:
- opts = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- opts.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(opts)
- sample_input = f"{question}\n{options_str}"
- except KeyError:
- sample_input = question
-
- print('Input:', sample_input)
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- full_text = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
-
- # Tokenize with an IMAGE token placeholder
- parts = full_text.split(DEFAULT_IMAGE_TOKEN)
- enc_parts = []
- for j, chunk in enumerate(parts):
- if j == 0:
- enc = tokenizer.encode(chunk)
- else:
- enc = tokenizer.encode(chunk, add_special_tokens=False)
- enc_parts.append(enc)
- assert len(enc_parts) == 2
-
- ids = []
- for j, enc in enumerate(enc_parts):
- ids.extend(enc)
- if j != len(enc_parts) - 1:
- ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(ids, device='cuda')
-
- with torch.no_grad():
- # Use the model’s internal pipeline: LongNet -> (DivPrune) -> Projector
- pixel_embeds = runner.model._encode_features(image.to(dtype =runner.model.llm.dtype))
- # print(f'Pixel embeds shape: {pixel_embeds.shape}')
- # Pack multimodal inputs and get segment indices
- packed, (input_embeds_indices,) = runner.model._prepare_inputs_labels_for_multimodal_dynamic(
- input_ids=input_ids.unsqueeze(0),
- labels=None,
- attention_mask=None,
- position_ids=None,
- past_key_values=None,
- pixel_values=pixel_embeds.to(runner.model.llm.dtype),
- return_indices=True
- )
- for k, v in packed.items():
- print(f"Packed {k}: {v.shape if isinstance(v, torch.Tensor) else v}")
-
- attn_impl = getattr(runner.model.llm.config, "_attn_implementation", "")
- attn = None if attn_impl == "flash_attention_2" else packed["attention_mask"]
-
- gen_config = GenerationConfig(
- max_new_tokens=1000,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words = []
- stop_words += PROMPT_TEMPLATE.qwen_chat.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList(
- [StopWordStoppingCriteria(tokenizer, w) for w in stop_words]
- )
-
- torch.cuda.synchronize()
- t0 = time.time()
- with torch.no_grad():
- out_ids = runner.model.llm.generate(
- inputs_embeds=packed["inputs_embeds"].to(runner.model.llm.dtype),
- attention_mask=packed.get("attention_mask", None),
- position_ids=packed.get("position_ids", None),
- generation_config=gen_config,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria,
- input_embeds_indices=input_embeds_indices,
- use_cache=True,
- )
- torch.cuda.synchronize()
- dt = time.time() - t0
- gen_times.append(dt)
-
- out_text = tokenizer.decode(out_ids[0], skip_special_tokens=False)
- if out_text.endswith('<|im_end|>'):
- out_text = out_text[:-10]
-
- print('Output:', out_text)
- print(f'Generation Time: {dt:.4f} s')
-
- # collect row
- row = {
- 'ID': df_test_case.loc[i].get('ID', i),
- 'Slide': df_test_case.loc[i].get('Slide', slide),
- 'Tumor': tumor,
- 'Broad Category': df_test_case.loc[i].get('Broad Category', ''),
- 'Narrow Category': df_test_case.loc[i].get('Narrow Category', ''),
- 'Question': question,
- 'A': df_test_case.loc[i].get('A', ''),
- 'B': df_test_case.loc[i].get('B', ''),
- 'C': df_test_case.loc[i].get('C', ''),
- 'D': df_test_case.loc[i].get('D', ''),
- 'Answer': df_test_case.loc[i].get('Answer', ''),
- 'Output': out_text,
- 'GenerationTime': dt
- }
- df_out.loc[i] = row
- df_out.to_csv(args.test_output_csv, index=False)
-
- print('Test loop finished.')
-
- if gen_times:
- avg_t = float(np.mean(gen_times))
- print(f"\nAverage Generation Time over {len(gen_times)} samples: {avg_t:.4f} seconds")
-
- # -------- Evaluation (Exact-match on first letter) --------
- def slidechat_performance(output_file_path, eval_output_path):
- df = pd.read_csv(output_file_path)
- if 'GenerationTime' not in df:
- df['GenerationTime'] = np.nan
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
- df['correct'] = df['Answer_clean'] == df['Output_clean']
- accuracy = df['correct'].mean()
- avg_gen_time = df['GenerationTime'].mean()
- total = len(df)
- correct = int(df['correct'].sum())
- summary = (
- f"Evaluation Summary:\n"
- f"---------------------\n"
- f"Total Samples : {total}\n"
- f"Correct : {correct}\n"
- f"Accuracy : {accuracy:.2%}\n"
- f"Average Generation Time : {avg_gen_time:.4f} seconds\n"
- )
- print(output_file_path)
- print(summary)
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path + '\n')
- f.write(summary + '\n')
-
- try:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
- except Exception as e:
- print(f"Error during performance evaluation: {e}")
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_fastv_llava.py b/code/xtuner/tools/test_fastv_llava.py
deleted file mode 100644
index bde1f094baac0d3a20397de39cdfc84748e6b86b..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_fastv_llava.py
+++ /dev/null
@@ -1,504 +0,0 @@
-#!/usr/bin/env python3
-"""
-FastV-enabled inference script for SlideChat model with visual token pruning.
-"""
-
-import argparse
-import os
-import os.path as osp
-import sys
-import time
-
-import numpy as np
-import pandas as pd
-import torch
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.utils import (
- DEFAULT_IMAGE_TOKEN,
- IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria,
- PROMPT_TEMPLATE,
-)
-from transformers import AutoTokenizer, GenerationConfig, StoppingCriteriaList
-
-# Make local fastv_qwen importable if needed (sibling file to this script)
-sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto="auto"
-)
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="Test SlideChat model with FastV pruning")
- parser.add_argument("config", help="config file name or path.")
- parser.add_argument("--checkpoint", required=True,
- help="Model checkpoint path (.pth file or directory)")
- parser.add_argument("--test_slide_csv", required=True,
- help="Benchmark CSV with questions & metadata")
- parser.add_argument("--test_output_csv", required=True,
- help="Where to write per-row outputs")
- parser.add_argument("--tumor_type", default=None,
- help="Filter rows by Tumor column (e.g., BRCA, SKCM)")
- parser.add_argument("--eval_output_path", default="slidechat_eval_fastv.txt",
- help="Path to append evaluation results")
- parser.add_argument("--torch-dtype", default="bf16",
- choices=TORCH_DTYPE_MAP.keys(),
- help="Model dtype (usually bf16)")
- parser.add_argument("--fastv_k", type=int, default=2,
- help="Decoder layer index to begin FastV pruning (1-indexed agg layer; pruning uses attention from k-1)")
- parser.add_argument("--fastv_r", type=float, default=0.5,
- help="Fraction of visual tokens to retain after pruning (0 str:
- """Resolve config names via xtuner.configs lookup if necessary."""
- if osp.isfile(cfg_arg):
- return cfg_arg
- try:
- return cfgs_name_path[cfg_arg]
- except KeyError as e:
- raise FileNotFoundError(f'Cannot find config "{cfg_arg}"') from e
-
-
-def _load_model_weights(runner: Runner, ckpt_path: str):
- """Load weights into the runner's model from a checkpoint file or folder."""
- # Accept a directory containing (possibly) a safetensors/pt/ckpt file or a direct file
- if osp.isdir(ckpt_path):
- # Try a few common filenames
- candidates = [
- osp.join(ckpt_path, "pytorch_model.bin"),
- osp.join(ckpt_path, "model.safetensors"),
- osp.join(ckpt_path, "adapter_model.bin"),
- osp.join(ckpt_path, "adapter_model.safetensors"),
- ]
- file_path = None
- for c in candidates:
- if osp.exists(c):
- file_path = c
- break
- if file_path is None:
- # fallback to xtuner helper
- state_dict = guess_load_checkpoint(ckpt_path)
- else:
- state_dict = torch.load(file_path, map_location="cpu") if file_path.endswith(".bin") else None
- if state_dict is None and file_path.endswith(".safetensors"):
- from safetensors.torch import load_file
- state_dict = load_file(file_path, device="cpu")
- else:
- state_dict = guess_load_checkpoint(ckpt_path)
-
- print(f'Available keys in checkpoint: {list(state_dict.keys())[:10]}...') # Show first 10 keys
- return runner.model.load_state_dict(state_dict, strict=False)
-
-
-def _find_lm_module(model):
- """
- Heuristically find the module that implements .generate() and owns the embedding/lm_head.
- Prefers Qwen2ForCausalLM if present; otherwise falls back to model itself if it has generate.
- """
- # Preferred: a top-level CausalLM wrapper
- for name in ["llm", "model", "decoder", "language_model"]:
- mod = getattr(model, name, None)
- if mod is not None and hasattr(mod, "generate"):
- return mod
-
- # If model itself has generate (e.g., CausalLM wrapper used directly)
- if hasattr(model, "generate"):
- return model
-
- # If nothing found, raise a descriptive error
- raise RuntimeError(
- "Could not find a module with `.generate()`. Expected Qwen2ForCausalLM (preferred) "
- "or a model exposing `.generate()`. Ensure the loaded model config uses a CausalLM "
- "wrapper (e.g., Qwen2ForCausalLM) around Qwen2Model for inference."
- )
-
-
-def _set_fastv(model_or_lm, config_dict: dict):
- """
- Apply FastV configuration either via `set_fastv_config` (preferred) or, if the
- module is a bare Qwen2Model, via the internal `_apply_fastv_config`.
- """
- if hasattr(model_or_lm, "set_fastv_config"):
- model_or_lm.set_fastv_config(config_dict)
- elif hasattr(model_or_lm, "_apply_fastv_config"):
- model_or_lm._apply_fastv_config(config_dict)
- elif hasattr(model_or_lm, "model") and hasattr(model_or_lm.model, "_apply_fastv_config"):
- # Some wrappers may expose underlying .model
- model_or_lm.model._apply_fastv_config(config_dict)
- else:
- print("Warning: No FastV configuration hook found; FastV may be inactive.")
-
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # Validate FastV parameters
- if not (0 < args.fastv_r <= 1):
- raise ValueError("--fastv_r must be in (0, 1]")
- if args.fastv_k < 0:
- raise ValueError("--fastv_k must be >= 0")
-
- # Resolve config path & load
- cfg_path = _maybe_resolve_config_path(args.config)
- cfg = Config.fromfile(cfg_path)
- cfg.launcher = args.launcher
-
- cfg_options = getattr(args, 'cfg_options', None)
- if cfg_options:
- cfg.merge_from_dict(cfg_options)
-
- work_dir = getattr(args, 'work_dir', None)
- if work_dir is not None:
- cfg.work_dir = work_dir
- elif cfg.get('work_dir', None) is None:
- cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(cfg_path))[0])
-
- # Build runner/model
- dtype = TORCH_DTYPE_MAP[args.torch_dtype]
- runner = Runner.from_cfg(cfg) if 'runner_type' not in cfg else RUNNERS.build(cfg)
-
- # dtype cast on the whole model
- runner.model = runner.model.to(dtype=dtype)
-
- # Load weights
- incompat = _load_model_weights(runner, args.checkpoint)
- print("\n--- Checkpoint Loading Report ---")
- missing_keys = getattr(incompat, 'missing_keys', [])
- unexpected_keys = getattr(incompat, 'unexpected_keys', [])
- if not missing_keys and not unexpected_keys:
- print("✅ Checkpoint loaded successfully with no mismatches.")
- else:
- if missing_keys:
- print(f"\n🔎 Missing keys: {len(missing_keys)} keys")
- for k in missing_keys[:10]:
- print(f" - {k}")
- if unexpected_keys:
- print(f"\n🔎 Unexpected keys: {len(unexpected_keys)} keys")
- for k in unexpected_keys[:10]:
- print(f" - {k}")
- print("---------------------------------\n")
-
- runner.model.eval()
- runner.logger.info(f'Model ready from {args.checkpoint}')
-
- # Tokenizer
- if hasattr(runner.model, "tokenizer"):
- tokenizer = runner.model.tokenizer
- else:
- # Fallback tokenizer; adjust if your training used a specific repo path
- tokenizer = AutoTokenizer.from_pretrained(
- 'Qwen/Qwen2.5-7B-Instruct', trust_remote_code=True, encode_special_tokens=True
- )
-
- # Locate the LM module for generation and embeddings
- lm = _find_lm_module(runner.model)
-
- # Read benchmark CSV
- if not args.test_slide_csv or not osp.exists(args.test_slide_csv):
- raise FileNotFoundError("--test_slide_csv must exist.")
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- # Prepare output DataFrame
- out_cols = [
- 'ID', 'Slide', 'Tumor', 'Broad Category', 'Narrow Category',
- 'Question', 'A', 'B', 'C', 'D', 'Answer', 'Output', 'GenerationTime'
- ]
- df_out = pd.DataFrame(columns=out_cols)
- if args.test_output_csv:
- os.makedirs(osp.dirname(args.test_output_csv) or '.', exist_ok=True)
-
- # Prompt template
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
-
- gen_times = []
- processed_count = 0
-
- # Main inference loop
- for i in range(df_test_case.shape[0]):
- print('*' * 30)
- slide = df_test_case.loc[i, 'Slide']
- tumor = df_test_case.loc[i, 'Tumor']
- print(f'Processing row {i}: slide={slide}, tumor={tumor}')
-
- # Tumor filter (optional)
- if args.tumor_type is not None and tumor != args.tumor_type:
- continue
-
- # Path to precomputed features
- feat_path = (
- '/data/weimin/PathVLM/dataset/TCGA_features/conch_v1/' +
- str(tumor.lower()) + '_224x224_b20_t15/pt_files/' + slide + '.pt'
- )
-
- if not osp.exists(feat_path):
- # Try alternative path or log missing
- alt_feat_path = feat_path.replace('.pt', '.csv')
- if osp.exists(alt_feat_path):
- feat_path = alt_feat_path
- else:
- missing_log = "missing_WSI_log.txt"
- with open(missing_log, "a") as f:
- f.write(feat_path + "\n")
- print(f"⚠️ Missing features: {feat_path}")
- continue
-
- # Load features
- try:
- if feat_path.endswith('.csv'):
- df = pd.read_csv(feat_path)
- df = df.iloc[:, :512]
- total = df.shape[0]
- sample_num = min(total, 10240)
- if total >= sample_num:
- idx = np.linspace(0, total - 1, sample_num, dtype=int)
- df = df.iloc[idx]
- image = torch.from_numpy(df.to_numpy().reshape(1, -1, 512))
- elif feat_path.endswith('.pt'):
- arr = torch.load(feat_path, map_location='cpu')
- if isinstance(arr, torch.Tensor):
- arr = arr.cpu().numpy()
- total = arr.shape[0]
- sample_num = min(total, 10240)
- if total >= sample_num:
- idx = np.linspace(0, total - 1, sample_num, dtype=int)
- arr = arr[idx]
- image = torch.from_numpy(arr.reshape(1, -1, 512))
- else:
- print(f"Unsupported file format: {feat_path}")
- continue
-
- image = image.cuda(non_blocking=True).to(dtype=dtype)
- print(f'Loaded image shape: {image.shape}, dtype: {image.dtype}')
-
- except Exception as e:
- print(f"Error loading features from {feat_path}: {e}")
- continue
-
- # Build prompt with IMAGE token
- question = df_test_case.loc[i, 'Question']
- try:
- opts = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- opts.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(opts)
- sample_input = f"{question}\n{options_str}"
- except KeyError:
- sample_input = question
-
- print(f'Input: {sample_input[:200]}...' if len(sample_input) > 200 else f'Input: {sample_input}')
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- full_text = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
-
- # Tokenize with IMAGE token placeholder
- parts = full_text.split(DEFAULT_IMAGE_TOKEN)
- enc_parts = []
- for j, chunk in enumerate(parts):
- if j == 0:
- enc = tokenizer.encode(chunk)
- else:
- enc = tokenizer.encode(chunk, add_special_tokens=False)
- enc_parts.append(enc)
-
- ids = []
- for j, enc in enumerate(enc_parts):
- ids.extend(enc)
- if j != len(enc_parts) - 1:
- ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(ids, device='cuda')
-
- try:
- with torch.no_grad():
- # Encode features through LongNet and projector on the main model
- # (Qwen2Model exposes LongNet_encoder/projector; wrappers keep them on the outer model)
- main_model = runner.model
- if hasattr(main_model, "LongNet_encoder") and hasattr(main_model, "projector"):
- pixel_embeds = main_model.LongNet_encoder(image)
- pixel_embeds = main_model.projector(pixel_embeds).to(dtype=dtype)
- else:
- raise RuntimeError(
- "Model does not expose LongNet_encoder/projector. Ensure your model is Qwen2Model-based."
- )
-
- # Prepare multimodal inputs: replace the IMAGE token with visual embeddings
- # We’ll map the single IMAGE_TOKEN_INDEX span to the full sequence of pixel_embeds
- inputs_embeds = lm.get_input_embeddings()(input_ids.unsqueeze(0))
- image_token_mask = (input_ids == IMAGE_TOKEN_INDEX)
- if image_token_mask.any():
- insert_at = image_token_mask.nonzero(as_tuple=False)[0].item()
- # Rebuild inputs_embeds: [text_before, pixel_embeds, text_after]
- before = inputs_embeds[:, :insert_at, :]
- after = inputs_embeds[:, insert_at + 1 :, :]
- inputs_embeds = torch.cat([before, pixel_embeds, after], dim=1)
- # else: no image token found; proceed with plain text
-
- # Compute FastV configuration
- image_token_mask = (input_ids == IMAGE_TOKEN_INDEX)
- if image_token_mask.any():
- sys_length = int(image_token_mask.nonzero(as_tuple=False)[0]) # number of text tokens before image
- else:
- sys_length = 0
- image_token_length = pixel_embeds.size(1)
-
- keep_rank = max(1, int(image_token_length * args.fastv_r))
- fastv_config = {
- 'use_fastv': True,
- 'fastv_k': args.fastv_k,
- 'fast_v_agg_layer': args.fastv_k,
- 'fast_v_attention_rank': keep_rank,
- 'fast_v_sys_length': sys_length,
- 'fast_v_image_token_length': image_token_length,
- 'fast_v_inplace': True,
- }
- print(f"FastV config: keeping {keep_rank}/{image_token_length} visual tokens at layer {args.fastv_k}")
-
- # Apply FastV to whichever module actually runs the forward pass
- _set_fastv(lm, fastv_config)
-
- # Generation configuration
- gen_config = GenerationConfig(
- max_new_tokens=1000,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
-
- stop_words = []
- stop_words += prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList([
- StopWordStoppingCriteria(tokenizer, w) for w in stop_words
- ])
-
- # Generate (disable cache with FastV)
- torch.cuda.synchronize()
- t0 = time.time()
- with torch.no_grad():
- out_ids = lm.generate(
- inputs_embeds=inputs_embeds.to(dtype=dtype),
- attention_mask=None,
- position_ids=None,
- generation_config=gen_config,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria,
- use_cache=False,
- )
- torch.cuda.synchronize()
- dt = time.time() - t0
- gen_times.append(dt)
-
- out_text = tokenizer.decode(out_ids[0], skip_special_tokens=False)
- if out_text.endswith('<|im_end|>'):
- out_text = out_text[:-10]
-
- print(f'Output: {out_text[:200]}...' if len(out_text) > 200 else f'Output: {out_text}')
- print(f'Generation Time: {dt:.4f} s')
-
- processed_count += 1
-
- except Exception as e:
- print(f"Error during generation: {e}")
- out_text = "ERROR"
- dt = -1
-
- # Collect row
- row = {
- 'ID': df_test_case.loc[i].get('ID', i),
- 'Slide': slide,
- 'Tumor': tumor,
- 'Broad Category': df_test_case.loc[i].get('Broad Category', ''),
- 'Narrow Category': df_test_case.loc[i].get('Narrow Category', ''),
- 'Question': question,
- 'A': df_test_case.loc[i].get('A', ''),
- 'B': df_test_case.loc[i].get('B', ''),
- 'C': df_test_case.loc[i].get('C', ''),
- 'D': df_test_case.loc[i].get('D', ''),
- 'Answer': df_test_case.loc[i].get('Answer', ''),
- 'Output': out_text,
- 'GenerationTime': dt,
- }
- df_out.loc[len(df_out)] = row
-
- # Save incrementally
- if processed_count % 10 == 0:
- df_out.to_csv(args.test_output_csv, index=False)
- print(f"Saved {processed_count} results to {args.test_output_csv}")
-
- # Final save
- df_out.to_csv(args.test_output_csv, index=False)
- print(f'\nTest loop finished. Processed {processed_count} samples.')
-
- if gen_times:
- avg_t = float(np.mean(gen_times))
- print(f"Average Generation Time: {avg_t:.4f} seconds")
- print(f"FastV configuration used: k={args.fastv_k}, r={args.fastv_r}")
-
- # Evaluation
- def evaluate_performance(output_file_path, eval_output_path):
- """Evaluate exact-match accuracy on multiple choice questions."""
- try:
- df = pd.read_csv(output_file_path)
- if 'GenerationTime' not in df:
- df['GenerationTime'] = np.nan
-
- # Extract first letter from output as answer
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
- df['correct'] = df['Answer_clean'] == df['Output_clean']
-
- accuracy = df['correct'].mean()
- valid_times = df['GenerationTime'][(df['GenerationTime'] > 0) & df['GenerationTime'].notna()]
- avg_gen_time = valid_times.mean() if len(valid_times) > 0 else 0.0
- total = len(df)
- correct = int(df['correct'].sum())
-
- summary = (
- f"FastV Evaluation Summary (k={args.fastv_k}, r={args.fastv_r}):\n"
- f"---------------------\n"
- f"Total Samples : {total}\n"
- f"Correct : {correct}\n"
- f"Accuracy : {accuracy:.2%}\n"
- f"Average Generation Time : {avg_gen_time:.4f} seconds\n"
- f"FastV Settings : k={args.fastv_k}, r={args.fastv_r}\n"
- )
- print(output_file_path)
- print(summary)
-
- with open(eval_output_path, 'a+') as f:
- f.write(f"\n{output_file_path}\n")
- f.write(summary + '\n')
-
- except Exception as e:
- print(f"Error during evaluation: {e}")
-
- if args.test_output_csv and osp.exists(args.test_output_csv):
- evaluate_performance(args.test_output_csv, args.eval_output_path)
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/test_for_json_files.py b/code/xtuner/tools/test_for_json_files.py
deleted file mode 100644
index b71bb701077c47863a46261a1a2ef0245b7b4aa0..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_for_json_files.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Folder inference for *_test.json using LLaVAModel's image path:
- feats/coords -> (optional) token_merge -> projector -> (pos-embed + perceiver) -> llm.generate
-
-This script is updated to also handle simpler, projector-only models by making
-the token_merge, perceiver, and positional embedding steps conditional.
-
-REQUIRES .h5 feature files that contain BOTH:
- - 'features': float32, shape [N, 512]
- - 'coords': int/long or float, shape [N, 2] (xy in pixels)
-
-Why HDF5 only? LLaVAModel needs coords; PT/CSV don't provide them here.
-For projector-only models, the coords are loaded but may be ignored.
-"""
-
-import argparse
-import os
-import os.path as osp
-import json
-import glob
-import time
-from types import FunctionType
-
-import numpy as np
-import pandas as pd
-import torch
-import h5py
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-from mmengine import print_log
-
-from transformers import AutoTokenizer, GenerationConfig, StoppingCriteriaList
-
-from xtuner.configs import cfgs_name_path
-from xtuner.registry import MAP_FUNC
-from xtuner.model.utils import guess_load_checkpoint, prepare_inputs_labels_for_multimodal
-from xtuner.utils import (
- DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria, PROMPT_TEMPLATE
-)
-
-TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-
-# -------------------- CLI --------------------
-def parse_args():
- ap = argparse.ArgumentParser("Folder JSON inference using LLaVAModel-compatible pipeline")
- ap.add_argument('config', help='xtuner/mmengine config (file path or cfgs_name_path key)')
- ap.add_argument('--checkpoint', default=None)
- ap.add_argument('--data_dir', required=True, help='Folder with many *_test.json')
- ap.add_argument('--feature_root', required=True,
- help='Root like .../conch_v1 (contains /h5_files/.h5)')
- ap.add_argument('--identifier', default='_224x224_b20_t15', help='Suffix directory appended to tumor folder')
- ap.add_argument('--image_feature_suffix', default='.h5', choices=['.h5'], # coords required → .h5 only
- help='Use .h5 because the model requires coords')
- ap.add_argument('--out_csv', required=True)
- ap.add_argument('--out_jsonl', default=None)
- ap.add_argument('--torch-dtype', default='bf16', choices=TORCH_DTYPE_MAP.keys())
- ap.add_argument('--work-dir', default=None)
- ap.add_argument('--cfg-options', nargs='+', action=DictAction)
- ap.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none')
- ap.add_argument('--local_rank', '--local-rank', type=int, default=0)
- # sampling options (match dataset behavior)
- ap.add_argument('--sample_num', type=int, default=10240)
- ap.add_argument('--sample_strategy', default='linspace', choices=['linspace', 'random', 'random_full'])
- ap.add_argument('--debug_max_samples', type=int, default=None)
- ap.add_argument('--max_new_tokens', type=int, default=512)
- return ap.parse_args()
-
-
-# -------------------- helpers --------------------
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for k, v in dict.items(cfg_dict):
- if isinstance(v, FunctionType):
- vs = str(v)
- if vs not in MAP_FUNC:
- MAP_FUNC.register_module(module=v, name=vs)
- cfg_dict[k] = vs
- else:
- register_function(v)
- elif isinstance(cfg_dict, (list, tuple)):
- for v in cfg_dict:
- register_function(v)
-
-
-def _parse_stub(image_path: str):
- """
- Extract tumor abbreviation (after the last '-') and the case (stem before the first '.').
- Example:
- image="TCGA-LUSC/TCGA-21-1083-01Z-00-DX1.8288c...pt"
- -> tumor_abbrev='lusc', case='TCGA-21-1083-01Z-00-DX1'
- """
- norm = os.path.normpath(image_path)
- parts = norm.split(os.sep)
-
- # parent folder like "TCGA-LUSC" -> "lusc"
- if len(parts) >= 2:
- parent = parts[-2]
- else:
- parent = 'unknown'
- tumor_abbrev = parent.split('-')[-1].lower() if '-' in parent else parent.lower()
-
- # file stem before first dot drops UUID
- stem = os.path.splitext(parts[-1])[0]
- case = stem.split('.', 1)[0]
- return tumor_abbrev, case
-
-
-def _build_feature_path(feature_root: str, tumor_abbrev: str, case_name: str,
- identifier: str, suffix: str):
- """
- Conch v1 layout:
- //h5_files/.h5
- Example:
- feature_root=/data/qingq/PathVLM/dataset/TCGA_features/conch_v1
- tumor_abbrev=lusc
- identifier=_224x224_b20_t15
- -> /data/.../conch_v1/lusc_224x224_b20_t15/h5_files/TCGA-...DX1.h5
- """
- if suffix != ".h5":
- raise ValueError("This pipeline requires .h5 features (with coords).")
- subdir = "h5_files"
- return os.path.join(feature_root, f"{tumor_abbrev}{identifier}", subdir, case_name + suffix)
-
-
-def _choose_indices(total_rows: int, k: int, strategy: str, rng: np.random.Generator):
- if total_rows <= 0:
- return np.array([], dtype=int)
- if strategy == "random_full":
- replace = total_rows < k
- idx = rng.choice(total_rows, size=k, replace=replace)
- return np.sort(idx.astype(int))
- if strategy == "random":
- if total_rows <= k:
- return np.arange(total_rows, dtype=int)
- idx = rng.choice(total_rows, size=k, replace=False)
- return np.sort(idx.astype(int))
- # linspace
- if total_rows <= k:
- return np.arange(total_rows, dtype=int)
- step = total_rows / k
- start = int(rng.integers(0, max(1, int(step))))
- indices = (np.floor(np.arange(k) * step + start)).astype(int)
- return np.clip(indices, 0, total_rows - 1)
-
-
-def _rng(seed: int = 3407):
- base = seed & 0xFFFFFFFF
- seed_mix = (base ^ (torch.initial_seed() & 0xFFFFFFFF)) & 0xFFFFFFFF
- return np.random.default_rng(seed_mix)
-
-
-def load_features_and_coords(
- image_field,
- feature_root: str,
- identifier: str,
- suffix: str,
- sample_num: int,
- sample_strategy: str,
- rng: np.random.Generator
-):
- """
- Returns:
- feats: torch.FloatTensor [T,512]
- coords: torch.Tensor [T,2] (int/float ok)
- meta_path: str
- """
- rel_path = image_field[0] if isinstance(image_field, list) else image_field
- tumor_name, case_base = _parse_stub(rel_path)
- feat_path = _build_feature_path(feature_root, tumor_name, case_base, identifier, suffix)
- if not osp.exists(feat_path):
- raise FileNotFoundError(f'feature not found: {feat_path}')
-
- with h5py.File(feat_path, 'r') as f:
- feats_np = f['features'][:] # (N,512)
- coords_np = f['coords'][:] # (N,2)
- if feats_np.shape[0] != coords_np.shape[0]:
- raise ValueError(f"Rows mismatch: features {feats_np.shape[0]} vs coords {coords_np.shape[0]} in {feat_path}")
-
- feats_np = feats_np.astype(np.float32, copy=False)
- total = feats_np.shape[0]
- idx = _choose_indices(total, sample_num, sample_strategy, rng)
-
- feats = torch.from_numpy(feats_np[idx]).float() # [T,512]
- coords = torch.from_numpy(coords_np[idx]) # [T,2], dtype may be int/float
- return feats, coords, feat_path
-
-
-def build_input_ids_from_prompt(tokenizer, raw_prompt: str):
- """
- Apply Qwen chat template and splice IMAGE_TOKEN_INDEX at .
- """
- prompt_tmpl = PROMPT_TEMPLATE.qwen_chat
- instruction = prompt_tmpl.get('INSTRUCTION', '{input}')
- if '' in raw_prompt:
- msg = instruction.format(input=raw_prompt, round=1)
- chunks = msg.split('')
- else:
- msg = instruction.format(input=(DEFAULT_IMAGE_TOKEN + '\n' + raw_prompt), round=1)
- chunks = msg.split(DEFAULT_IMAGE_TOKEN)
-
- encoded = []
- for i, ch in enumerate(chunks):
- # first chunk uses add_special_tokens=True, others False
- ids = tokenizer.encode(ch, add_special_tokens=(i == 0))
- encoded.extend(ids)
- if i != len(chunks) - 1:
- encoded.append(IMAGE_TOKEN_INDEX)
- return torch.tensor(encoded, dtype=torch.long)
-
-
-# -------------------- main --------------------
-def main():
- args = parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
- dtype = TORCH_DTYPE_MAP[args.torch_dtype]
-
- # ----- config / runner / model -----
- cfg_path = args.config if osp.isfile(args.config) else cfgs_name_path.get(args.config, None)
- if not cfg_path:
- raise FileNotFoundError(f'Cannot find config {args.config}')
- cfg = Config.fromfile(cfg_path)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
- register_function(cfg._cfg_dict)
-
- cfg.work_dir = args.work_dir or cfg.get('work_dir') or osp.join('./work_dirs', osp.splitext(osp.basename(cfg_path))[0])
-
- runner = Runner.from_cfg(cfg) if 'runner_type' not in cfg else RUNNERS.build(cfg)
- if args.checkpoint:
- sd = guess_load_checkpoint(args.checkpoint)
- missing, unexpected = runner.model.load_state_dict(sd, strict=False)
- print(f'✅ missing keys: {len(missing)} | ⚠️ unexpected: {len(unexpected)}')
-
- model = runner.model.eval().cuda()
- llm = model.llm.eval()
- projector = model.projector.eval()
- token_merge = getattr(model, 'token_merge', None)
- perceiver = getattr(model, 'perceiver', None)
- use_perceiver = bool(getattr(model, 'use_perceiver_resampler', False) and perceiver is not None)
- slide_ngrids = int(getattr(model, 'slide_ngrids', 1000))
- pe_drop = getattr(model, 'pe_drop', None)
- pe_gate = getattr(model, 'pe_gate', None)
- pos_embed = getattr(model, 'pos_embed', None)
-
- # ----- tokenizer / decoding -----
- llm_name_or_path = getattr(llm.config, "_name_or_path", "Qwen/Qwen2.5-7B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained(llm_name_or_path, trust_remote_code=True, encode_special_tokens=True)
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- stop_words = prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList([StopWordStoppingCriteria(tokenizer, w) for w in stop_words])
- gen_config = GenerationConfig(
- max_new_tokens=args.max_new_tokens,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
-
- # ----- collect *_test.json -----
- json_files = sorted(glob.glob(osp.join(args.data_dir, '*_test.json')))
- if not json_files:
- raise FileNotFoundError(f'No *_test.json found in {args.data_dir}')
-
- # outputs
- rows, times = [], []
- if args.out_jsonl:
- os.makedirs(osp.dirname(args.out_jsonl) or '.', exist_ok=True)
- jf = open(args.out_jsonl, 'w', encoding='utf-8')
- else:
- jf = None
- os.makedirs(osp.dirname(args.out_csv) or '.', exist_ok=True)
-
- rng = _rng(3407)
-
- for jfpath in json_files:
- print(f'\n=== {osp.basename(jfpath)} ===')
- with open(jfpath, 'r', encoding='utf-8') as f:
- data = json.load(f)
- samples = data if isinstance(data, list) else [data]
- if args.debug_max_samples is not None:
- samples = samples[: int(args.debug_max_samples)]
-
- for sample in samples:
- sid = str(sample.get('id', ''))
- image_field = sample.get('image', '')
- conv = sample.get('conversations', [])
- human_msgs = [c for c in conv if c.get('from') == 'human']
- raw_prompt = human_msgs[0]['value'] if human_msgs else '\nDescribe the slide.'
-
- # --- extract reference/ground-truth answer from test json (if present) ---
- def _get_ref_answer(conversations):
- if not conversations:
- return ""
- # Prefer 'gpt' like your example, but fall back to a few common keys
- priority_roles = ('gpt', 'assistant', 'answer', 'ref', 'label')
- for role in priority_roles:
- vals = [c.get('value', '') for c in conversations if str(c.get('from', '')).lower() == role]
- if vals:
- return vals[-1] # take the last in case there are multiple
- return ""
-
- ref_answer = _get_ref_answer(conv)
- # ---- load feats & coords (HDF5 with coords required) ----
- try:
- feats, coords, feat_path = load_features_and_coords(
- image_field=image_field,
- feature_root=args.feature_root,
- identifier=args.identifier,
- suffix=args.image_feature_suffix,
- sample_num=args.sample_num,
- sample_strategy=args.sample_strategy,
- rng=rng
- )
- except Exception as e:
- print(f'! skip {sid}: {e}')
- continue
-
- # ---- devices/dtypes ----
- feats = feats.cuda().to(dtype) # [T,512]
- # coords may be int/float; hook/model cast them to llm dtype before use
- coords = coords.cuda() # [T,2], keep original numerical type
-
- # ---- build text ids with splice ----
- input_ids = build_input_ids_from_prompt(tokenizer, raw_prompt).cuda().unsqueeze(0) # (1,L)
-
- # ---- vision path: token_merge → projector → (pos + perceiver) ----
- pixel_values = feats.unsqueeze(0) # (1,T,512)
- with torch.cuda.amp.autocast(enabled=(dtype in (torch.float16, torch.bfloat16)), dtype=dtype):
- pixel_values = pixel_values.to(llm.dtype)
-
- coords_rc = None # Will be computed only if needed by token_merge or perceiver
- # Check if any component requires coordinate processing
- if (token_merge and getattr(model, 'enable_token_merge', False)) or use_perceiver:
- if coords is None:
- raise ValueError("Model requires coords for token_merge or perceiver, but none were found.")
-
- coords_t = coords
- if coords_t.dim() == 2:
- coords_rc_unmapped = coords_t
- elif coords_t.dim() == 3:
- if coords_t.size(0) != 1:
- raise NotImplementedError("Batch coords >1 not supported here.")
- coords_rc_unmapped = coords_t[0]
- else:
- raise ValueError("coords must be [L,2] or [B,L,2].")
- if coords_rc_unmapped.size(-1) != 2:
- raise ValueError("coords last dim must be 2.")
-
- if not hasattr(model, '_coords_to_rowcol'):
- raise AttributeError("Model is missing `_coords_to_rowcol` required for token_merge/perceiver.")
- coords_rc = model._coords_to_rowcol(coords_rc_unmapped)
-
- # Optional Token Merge (pre-projector)
- if token_merge is not None and getattr(model, 'enable_token_merge', False):
- if coords_rc is None:
- raise ValueError("Token merge is enabled but coords were not processed.")
- padmask = torch.zeros((pixel_values.size(0), pixel_values.size(1)), dtype=torch.bool, device=pixel_values.device)
- pixel_values, coords_rc, _ = token_merge(x=pixel_values, coords_rc=coords_rc, padmask=padmask)
-
- # Projector (always runs, on raw or merged tokens)
- pixel_values = projector(pixel_values)
-
- # Optional Perceiver with Positional Embeddings (post-projector)
- if use_perceiver:
- if coords_rc is None:
- raise ValueError("Perceiver is enabled but coords were not processed.")
-
- text_emb = llm.get_input_embeddings()(input_ids.clamp(min=0)).to(llm.dtype).detach()
-
- if not hasattr(model, '_coords_rc_to_pos'):
- raise AttributeError("Model missing `_coords_rc_to_pos` required for perceiver PE.")
- pos = model._coords_rc_to_pos(coords_rc, int(slide_ngrids))
-
- if (pos_embed is not None) and (pe_gate is not None) and (pe_drop is not None):
- pixel_values = pixel_values + pe_drop(pos_embed[:, pos, :].squeeze(0).to(pixel_values.dtype) * pe_gate)
-
- pixel_values = perceiver(
- text_embeddings=text_emb,
- attention_mask=None,
- visual_tokens=pixel_values.to(llm.dtype),
- )
-
- # ---- pack & generate ----
- mm_inputs = prepare_inputs_labels_for_multimodal(llm=llm, input_ids=input_ids, pixel_values=pixel_values)
- torch.cuda.synchronize(); t0 = time.time()
- out_ids = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria
- )
- torch.cuda.synchronize(); dt = time.time() - t0
-
- out_text = tokenizer.decode(out_ids[0])
- # common Qwen chat tail
- if out_text.endswith('<|im_end|>'):
- out_text = out_text[:-10]
-
- print(f'[{sid}] {dt:.3f}s → {out_text[:120].strip()}{"..." if len(out_text) > 120 else ""}')
- times.append(dt)
-
- rec = dict(
- file=osp.basename(jfpath),
- id=sid,
- image=image_field if isinstance(image_field, str) else image_field[0],
- feature_path=feat_path,
- prompt=raw_prompt,
- output=out_text,
- ref_answer=ref_answer,
- gen_time_sec=dt
- )
- rows.append(rec)
- if jf:
- jf.write(json.dumps(rec, ensure_ascii=False) + '\n')
-
- if args.out_jsonl and jf:
- jf.close()
-
- pd.DataFrame(rows).to_csv(args.out_csv, index=False)
- if times:
- print(f'\nAverage generation time over {len(times)} samples: {np.mean(times):.4f}s')
- print(f'Wrote: {args.out_csv}')
- if args.out_jsonl:
- print(f'Wrote: {args.out_jsonl}')
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_fusion_compressor.py b/code/xtuner/tools/test_fusion_compressor.py
deleted file mode 100644
index 07778148f33d59820dea8b374f2648e7b75bc4a0..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_fusion_compressor.py
+++ /dev/null
@@ -1,519 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-from types import FunctionType
-import deepspeed
-import time
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-from sympy import im
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.registry import MAP_FUNC
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-import torch
-from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-# Import for creating the attention mask for pre-fusion layers
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from xtuner.utils import PROMPT_TEMPLATE
-from PIL import Image
-import pandas as pd
-import numpy as np
-from transformers import GenerationConfig, StoppingCriteriaList
-
-import os
-
-# <<< ADDED: Imports for visualization >>>
-import matplotlib.pyplot as plt
-import seaborn as sns
-from sklearn.manifold import TSNE
-# <<< END ADDED >>>
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', default=None, help='checkpoint file')
- parser.add_argument('--test_slide_csv', default=None, help='test_slide_csv')
- parser.add_argument('--test_output_csv', default=None, help='test_output_csv')
- parser.add_argument('--tumor_type', default=None, help='test_output_csv')
- parser.add_argument(
- '--eval_output_path',
- default='slidechat_baseline_eval.txt',
- help='path to save evaluation results')
-
- parser.add_argument(
- '--torch-dtype',
- default='bf16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--work-dir',
- help='the directory to save the file containing evaluation metrics')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
- args = parser.parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in dict.items(cfg_dict):
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-def visualize_tokens(text_tokens, visual_tokens, compressed_tokens, case_name, save_path='token_visualization.png'):
- """
- Applies t-SNE to visualize the different token types and saves the plot.
- """
- # 1. Prepare data: move to CPU, detach, convert to numpy, and remove batch dim
- text_np = text_tokens.squeeze(0).detach().float().cpu().numpy()
- visual_np = visual_tokens.squeeze(0).detach().float().cpu().numpy()
- compressed_np = compressed_tokens.squeeze(0).detach().float().cpu().numpy()
-
- # 2. Create labels for each token type
- labels = ['Text'] * text_np.shape[0] + \
- ['Visual'] * visual_np.shape[0] + \
- ['Compresseion'] * compressed_np.shape[0]
-
- # 3. Combine all tokens into one array
- all_tokens = np.concatenate((text_np, visual_np, compressed_np), axis=0)
-
- # 4. Apply t-SNE
- print("Running t-SNE... this may take a moment.")
- tsne = TSNE(n_components=2, perplexity=30, max_iter=1000, random_state=42, init='pca', learning_rate='auto')
- tokens_2d = tsne.fit_transform(all_tokens)
-
- # 5. Create a DataFrame for plotting
- df_plot = pd.DataFrame({
- 'x': tokens_2d[:, 0],
- 'y': tokens_2d[:, 1],
- 'type': labels
- })
-
- # 6. Plot using seaborn
- plt.figure(figsize=(12, 10))
- sns.scatterplot(
- data=df_plot,
- x='x',
- y='y',
- hue='type',
- palette={'Text': 'blue', 'Visual': 'green', 'Compresseion': 'red'},
- alpha=0.7,
- s=50 # marker size
- )
- # plt.title(f't-SNE Visualization of Token Types for {case_name}') #<-- REMOVED
- plt.xlabel('t-SNE Dimension 1', fontsize=16) #<-- MODIFIED
- plt.ylabel('t-SNE Dimension 2', fontsize=16) #<-- MODIFIED
- plt.legend(title='Token Type', fontsize=14, title_fontsize=16) #<-- MODIFIED
- plt.tight_layout()
- plt.savefig(save_path)
- plt.close()
- print(f"Visualization saved to {save_path}")
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- # register FunctionType object in cfg to `MAP_FUNC` Registry and
- # change these FunctionType object to str
- register_function(cfg._cfg_dict)
-
- # work_dir is determined in this priority: CLI > segment in file > filename
- if args.work_dir is not None:
- # update configs according to CLI args if args.work_dir is not None
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- # use config filename as default work_dir if cfg.work_dir is None
- cfg.work_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.config))[0])
-
- # build the runner from config
- if 'runner_type' not in cfg:
- # build the default runner
- runner = Runner.from_cfg(cfg)
- else:
- # build customized runner from the registry
- # if 'runner_type' is set in the cfg
- runner = RUNNERS.build(cfg)
-
- model_kwargs = {
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- }
-
- state_dict = guess_load_checkpoint(args.checkpoint)
- print(f'available keys in checkpoint: {state_dict.keys()}')
- runner.model.load_state_dict(state_dict, strict=False)
-
- missing_keys, unexpected_keys = runner.model.load_state_dict(state_dict, strict=False).missing_keys, \
- runner.model.load_state_dict(state_dict, strict=False).unexpected_keys
-
- print("✅ Missing keys (not in checkpoint):")
- for key in missing_keys:
- print(f" - {key}")
-
- print("\n⚠️ Unexpected keys (in checkpoint but not in model):")
- for key in unexpected_keys:
- print(f" - {key}")
-
- image_only = runner.model.image_only if hasattr(runner.model, 'image_only') else False
- runner.model.eval()
- runner.logger.info(f'Load checkpoint from {args.checkpoint}')
-
-
- llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
- tokenizer = AutoTokenizer.from_pretrained(
- llm_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
-
- llm = runner.model.llm
- llm.eval()
-
- LongNet_encoder = runner.model.LongNet_encoder.to(model_kwargs['torch_dtype'])
- LongNet_encoder.cuda()
- LongNet_encoder.eval()
-
- projector = runner.model.projector.to(model_kwargs['torch_dtype'])
- projector.cuda()
- projector.eval()
-
- # Load prefusion layers if they exist in the model
- prefusion_layers = None
- prefusion_layer_num = 0
- if hasattr(runner.model, 'prefusion_layers') and runner.model.prefusion_layer_num > 0:
- prefusion_layer_num = runner.model.prefusion_layer_num
- print(f"Found {prefusion_layer_num} prefusion layers. Loading them.")
- prefusion_layers = runner.model.prefusion_layers.to(model_kwargs['torch_dtype'])
- prefusion_layers.cuda()
- prefusion_layers.eval()
- if hasattr(runner.model, 'query_emb'):
- print('Loading query embeddings...')
- query_emb = runner.model.query_emb.to(model_kwargs['torch_dtype'])
- query_emb.cuda()
- else:
- print("No prefusion layers found in the model.")
-
-
- df_test_case = pd.read_csv(args.test_slide_csv)
- df_test_case['Output'] = df_test_case.apply(lambda x: '', axis=1)
- columns = ['ID','Slide','Tumor','Broad Category','Narrow Category','Question','A','B','C','D','Answer','Output', 'GenerationTime']
- df_test_output = pd.DataFrame(columns=columns)
- generation_times = []
-
- if args.test_output_csv:
- output_dir = os.path.dirname(args.test_output_csv)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- # A flag to ensure visualization is only done once
- visualization_done = False
-
- for i in range(df_test_case.shape[0]):
- print('*'*30)
- print('id: ', i, df_test_case.loc[i, 'Slide'])
- print('tumor type: ', args.tumor_type)
- if df_test_case.loc[i, 'Tumor'] != args.tumor_type:
- continue
- print('tumor name: ', df_test_case.loc[i, 'Tumor'])
- tumor_name = df_test_case.loc[i, 'Tumor']
- case_name = df_test_case.loc[i, 'Slide']
-
- test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/' + str( tumor_name.lower() ) + '_224x224_b20_t15/pt_files/' + case_name + '.pt'
-
- if not os.path.exists(test_image_file):
- with open("/data/qingq/PathVLM/baselines/github/SlideChat/outputs/missing_WSI_log.txt", "w") as f:
- f.write(test_image_file + "\n")
- f.close()
- continue
-
- if test_image_file.endswith('.pt'):
- image = torch.load(test_image_file, map_location='cpu') # (N, 512)
- image = image.numpy()
- total_rows = image.shape[0]
- print('before sampling image shape', image.shape)
- sample_num = 10000
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_image = image[indices]
- image = sampled_image[:sample_num]
-
- image = torch.from_numpy(image.reshape(1, -1, 512))
- print('final image shape', image.shape)
- else:
- image = Image.open(test_image_file).convert('RGB')
-
- image = image.cuda()
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
- question = df_test_case.loc[i, 'Question']
- options = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- options.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(options)
-
- sample_input = f"{question}\n{options_str}"
- print('Input: ', sample_input)
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(
- chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids).cuda()
-
- image = image.to(projector.dtype)
- feat_to_proj = image.permute(1, 0, 2)
- long_net_output = LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj)["encoder_out"]
- feat_to_proj = long_net_output.permute(1, 0, 2)
-
- mm_inputs_kwargs = {}
-
- if prefusion_layer_num > 0 and prefusion_layers is not None:
- print('Applying pre-fusion layers...')
- projected_global_image_features = projector(feat_to_proj)
- print(f'Global visual features shape: {projected_global_image_features.shape}')
-
- batched_input_ids = input_ids.unsqueeze(0)
- padding_mask = (batched_input_ids <= 0)
-
- B, D, T_visual = projected_global_image_features.shape
- Q = query_emb.unsqueeze(0).expand(B, -1, -1)
-
- text_embeddings = llm.get_input_embeddings()(batched_input_ids.clamp(min=0)).detach()
- text_mask = ~padding_mask
-
- if image_only:
- x = torch.cat([projected_global_image_features, Q], dim=1)
- mask = torch.cat((
- torch.zeros((padding_mask.size(0),projected_global_image_features.size(1)),
- device=padding_mask.device).bool(),
- torch.ones(padding_mask.size(0), Q.size(1),
- device = padding_mask.device).bool()
- ),
- dim=1)
- else:
- x = torch.cat([text_embeddings, projected_global_image_features, Q], dim=1)
- mask=torch.cat((text_mask,
- torch.zeros((padding_mask.size(0),projected_global_image_features.size(1)),
- device=padding_mask.device).bool(),
- torch.ones(padding_mask.size(0), Q.size(1),
- device = padding_mask.device).bool()
- ),
- dim=1)
-
- if getattr(llm, "_use_flash_attention_2", False) or \
- getattr(llm.config, "_attn_implementation", "") == "flash_attention_2":
- attention_mask = (~mask).int()
- else:
- attention_mask =_prepare_4d_causal_attention_mask(~mask, (x.size(0), x.size(1)), x, 0)
-
- position_ids = (~mask).int().long().cumsum(-1) - 1
- position_ids.masked_fill_((~mask).int() == 0, 1)
-
- for layer in prefusion_layers:
- x = layer(
- x,
- attention_mask=attention_mask,
- position_ids=position_ids,
- use_cache=False,
- )[0]
-
- pixel_values=x[:,-1*runner.model.compressor_grid_size:,:]
-
- # <<< ADDED: Visualization Logic >>>
- if not visualization_done:
- try:
- output_dir = "/data/qingq/PathVLM/baselines/github/SlideChat/outputs/visualizations"
- os.makedirs(output_dir, exist_ok=True) # Ensure the directory exists
- vis_save_path = os.path.join(output_dir, f'{case_name}_{tumor_name}_token_visualization.png')
-
- visualize_tokens(
- text_tokens=text_embeddings,
- visual_tokens=projected_global_image_features,
- compressed_tokens=pixel_values,
- case_name=case_name,
- save_path=vis_save_path
- )
- visualization_done = True # Set flag to true after one successful visualization
- except Exception as e:
- print(f"Could not generate visualization for case {case_name}: {e}")
- # <<< END of Visualization Logic >>>
-
- else:
- pixel_values = projector(feat_to_proj)
-
-
- # Final preparation of inputs for the LLM
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=runner.model.llm,
- input_ids=input_ids.unsqueeze(0),
- pixel_values=pixel_values,
- **mm_inputs_kwargs)
-
- max_new_tokens=500
- gen_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words=[]
- stop_words += prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- stop_criteria.append(
- StopWordStoppingCriteria(tokenizer, word))
-
- torch.cuda.synchronize()
- start_time = time.time()
-
- generate_output = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
-
- torch.cuda.synchronize()
- end_time = time.time()
- duration = end_time - start_time
- generation_times.append(duration)
-
- generation_output = tokenizer.decode(generate_output[0])
- if generation_output.endswith('<|im_end|>'):
- generation_output = generation_output[:-10]
-
- print('Output: ', generation_output)
- print(f'Generation Time: {duration:.4f} seconds')
-
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'A': df_test_case.loc[i, 'A'],
- 'B': df_test_case.loc[i, 'B'],
- 'C': df_test_case.loc[i, 'C'],
- 'D': df_test_case.loc[i, 'D'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output,
- 'GenerationTime': duration
- }
- df_test_output.loc[i] = add_row
- if args.test_output_csv:
- df_test_output.to_csv(args.test_output_csv, index=False)
-
- torch.cuda.empty_cache()
- if prefusion_layer_num > 0 and prefusion_layers is not None:
- del x
- del attention_mask
- del position_ids
-
-
- print('Test ok!')
-
- if generation_times:
- average_time = np.mean(generation_times)
- print(f"\nAverage Generation Time over {len(generation_times)} samples: {average_time:.4f} seconds")
-
-
- def slidechat_performance(output_file_path, eval_output_path):
- df = pd.read_csv(output_file_path)
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
- df['correct'] = df['Answer_clean'] == df['Output_clean']
- accuracy = df['correct'].mean()
- average_gen_time = df['GenerationTime'].mean()
- total = len(df)
- correct = df['correct'].sum()
- print(f"Exact Match Accuracy: {accuracy:.2%} ({correct}/{total})")
- print(f"Average Generation Time: {average_gen_time:.4f} seconds")
-
- result_text = f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Correct : {correct}
- Accuracy : {accuracy:.2%}
- Average Generation Time : {average_gen_time:.4f} seconds
- """
-
- print(output_file_path)
- print(result_text)
-
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path)
- f.write('\n')
- f.write(result_text)
- f.write('\n')
-
- if args.test_output_csv:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_llm_only.py b/code/xtuner/tools/test_llm_only.py
deleted file mode 100644
index ceb23417cf755f38d83babe7da03e1a0842abb3a..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_llm_only.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# MODIFIED: This script has been altered to test the LLM part ONLY.
-# All vision-related components (image loading, encoders, projectors) have been removed.
-
-import argparse
-import os
-import os.path as osp
-from types import FunctionType
-import deepspeed
-import time # <<< ADDED: Import the time module
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-# from sympy import im # <<< REMOVED: Unused import
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.registry import MAP_FUNC
-from xtuner.utils import (
- # DEFAULT_IMAGE_TOKEN, # <<< REMOVED: Not needed for text-only
- # IMAGE_TOKEN_INDEX, # <<< REMOVED: Not needed for text-only
- StopWordStoppingCriteria)
-import torch
-# from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal # <<< REMOVED: Multimodal function not needed
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- # BitsAndBytesConfig, # <<< REMOVED: Unused import
- # CLIPImageProcessor, # <<< REMOVED: Vision model not needed
- # CLIPVisionModel, # <<< REMOVED: Vision model not needed
- GenerationConfig)
-from xtuner.utils import PROMPT_TEMPLATE
-# from PIL import Image # <<< REMOVED: Image library not needed
-import pandas as pd
-import numpy as np
-from transformers import GenerationConfig, StoppingCriteriaList
-
-import os
-# from xtuner.model.llava_dim_reducer import TextGuidedVisualTokenAttentionReducer # <<< REMOVED: Vision component
-
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', default=None, help='checkpoint file')
- parser.add_argument('--test_slide_csv', default=None, help='test_slide_csv')
- parser.add_argument('--test_output_csv', default=None, help='test_output_csv')
- parser.add_argument('--tumor_type', default=None, help='test_output_csv')
- parser.add_argument(
- '--eval_output_path',
- default='slidechat_baseline_eval.txt',
- help='path to save evaluation results')
-
- parser.add_argument(
- '--torch-dtype',
- default='bf16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--work-dir',
- help='the directory to save the file containing evaluation metrics')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- # parser.add_argument('--divprune_ratio', type = float, default = None, help='the ratio for divprune') # <<< REMOVED: Vision-related argument
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
- args = parser.parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in dict.items(cfg_dict):
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-# <<< REMOVED: DivPrune and related helper functions are not needed for text-only evaluation
-# def pairwise_l1_distance(matrix: torch.Tensor) -> torch.Tensor:
-# # ...
-#
-# def pairwise_cosine_similarity(matrix):
-# # ...
-#
-# def DivPrune(visual_feature_vectors, image_feature_length,
-# cosine_matrix=None, threshold_ratio=0.1):
-# # ...
-
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- # register FunctionType object in cfg to `MAP_FUNC` Registry and
- # change these FunctionType object to str
- register_function(cfg._cfg_dict)
-
- # work_dir is determined in this priority: CLI > segment in file > filename
- if args.work_dir is not None:
- # update configs according to CLI args if args.work_dir is not None
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- # use config filename as default work_dir if cfg.work_dir is None
- cfg.work_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.config))[0])
-
- # build the runner from config
- if 'runner_type' not in cfg:
- # build the default runner
- runner = Runner.from_cfg(cfg)
- else:
- # build customized runner from the registry
- # if 'runner_type' is set in the cfg
- runner = RUNNERS.build(cfg)
-
- # model_kwargs = {
- # 'trust_remote_code': True,
- # 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- # } # <<< REMOVED: Moved dtype to llm definition
-
- state_dict = guess_load_checkpoint(args.checkpoint)
- print(f'available keys in checkpoint: {state_dict.keys()}')
- runner.model.load_state_dict(state_dict, strict=False)
-
-
- ##############################qingq check loaded weights######################################
- missing_keys, unexpected_keys = runner.model.load_state_dict(state_dict, strict=False).missing_keys, \
- runner.model.load_state_dict(state_dict, strict=False).unexpected_keys
-
- print("✅ Missing keys (not in checkpoint):")
- for key in missing_keys:
- print(f" - {key}")
-
- print("\n⚠️ Unexpected keys (in checkpoint but not in model):")
- for key in unexpected_keys:
- print(f" - {key}")
- ##############################qingq check loaded weights######################################
-
-
-
- runner.model.eval()
- runner.logger.info(f'Load checkpoint from {args.checkpoint}')
-
-
- llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
- tokenizer = AutoTokenizer.from_pretrained(
- llm_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
-
- llm = runner.model.llm # dtype: float16
- llm.to(TORCH_DTYPE_MAP[args.torch_dtype]) # <<< MODIFIED: Ensure LLM has the correct dtype
- llm.eval()
-
- # <<< REMOVED: All vision model components are no longer needed.
- # LongNet_encoder = runner.model.LongNet_encoder.to(model_kwargs['torch_dtype'])
- # projector = runner.model.projector.to(model_kwargs['torch_dtype'])
- # visual_token_reducer = None
- # mil = None
- # ... and their associated setup and eval calls ...
-
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- df_test_case['Output'] = df_test_case.apply(lambda x: '', axis=1)
- columns = ['ID','Slide','Tumor','Broad Category','Narrow Category','Question','A','B','C','D','Answer','Output', 'GenerationTime']
- df_test_output = pd.DataFrame(columns=columns)
- generation_times = []
-
- if args.test_output_csv:
- output_dir = os.path.dirname(args.test_output_csv)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- for i in range(df_test_case.shape[0]):
- print('*'*30)
- print('id: ', i, df_test_case.loc[i, 'Slide'])
- print('tumor type: ', args.tumor_type)
- if df_test_case.loc[i, 'Tumor'] != args.tumor_type:
- continue
- print('tumor name: ', df_test_case.loc[i, 'Tumor'])
-
- # <<< REMOVED: All image loading and processing logic has been removed.
- # test_image_file = ...
- # if not os.path.exists(test_image_file): ...
- # if test_image_file.endswith('.csv'): ...
- # elif test_image_file.endswith('.pt'): ...
- # else: image = Image.open(...)
- # image = image.cuda()
-
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
- question = df_test_case.loc[i, 'Question']
- try:
- options = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- options.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(options)
-
- sample_input = f"{question}\n{options_str}"
- print('Input: ', sample_input)
- except KeyError as e:
- sample_input = question
- print('Input: ', sample_input)
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
-
- # <<< MODIFIED: Input construction for text-only.
- # No image token placeholders are needed.
- inputs = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
-
- # Tokenize the text-only input
- input_ids = tokenizer(inputs, return_tensors='pt').input_ids.cuda()
-
-
- # <<< REMOVED: The entire multimodal input preparation block is no longer necessary.
- # chunk_encode = [] ...
- # input_ids.append(IMAGE_TOKEN_INDEX) ...
- # image = image.to(projector.dtype) ...
- # pixel_values = projector(image) ...
- # divprune logic ...
- # visual_token_reducer logic ...
- # mm_inputs = prepare_inputs_labels_for_multimodal(...)
-
- max_new_tokens=4096
- gen_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=True,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words=[]
- stop_words += prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- stop_criteria.append(
- StopWordStoppingCriteria(tokenizer, word))
-
- torch.cuda.synchronize()
- start_time = time.time()
-
- # <<< MODIFIED: Call generate with text `input_ids` directly.
- generate_output = llm.generate(
- input_ids,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
-
- torch.cuda.synchronize()
- end_time = time.time()
- duration = end_time - start_time
- generation_times.append(duration)
-
- # The rest of the loop remains the same
- generation_output = tokenizer.decode(generate_output[0][input_ids.shape[1]:]) # Decode only the generated tokens
- if generation_output.endswith('<|im_end|>'):
- generation_output = generation_output[:-10]
-
- print('Output: ', generation_output)
- print(f'Generation Time: {duration:.4f} seconds')
- try:
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'A': df_test_case.loc[i, 'A'],
- 'B': df_test_case.loc[i, 'B'],
- 'C': df_test_case.loc[i, 'C'],
- 'D': df_test_case.loc[i, 'D'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output,
- 'GenerationTime': duration
- }
- except:
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output,
- 'GenerationTime': duration
- }
- df_test_output.loc[i] = add_row
- if args.test_output_csv:
- df_test_output.to_csv(args.test_output_csv, index=False)
-
- print('Test ok!')
-
- if generation_times:
- average_time = np.mean(generation_times)
- print(f"\nAverage Generation Time over {len(generation_times)} samples: {average_time:.4f} seconds")
-
-
- # check performance
- def slidechat_performance(output_file_path, eval_output_path):
-
- # Load the CSV
- df = pd.read_csv(output_file_path)
-
-
- # Clean ground-truth answers
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
-
- # Extract the letter before the period in 'Output' (e.g., 'A. Luminal A' → 'A')
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
-
- # Compute exact match
- df['correct'] = df['Answer_clean'] == df['Output_clean']
-
- # Calculate accuracy
- accuracy = df['correct'].mean()
-
- average_gen_time = df['GenerationTime'].mean()
-
- # Print summary
- total = len(df)
- correct = df['correct'].sum()
- print(f"Exact Match Accuracy: {accuracy:.2%} ({correct}/{total})")
- print(f"Average Generation Time: {average_gen_time:.4f} seconds")
-
-
- # Build the result string
- result_text = f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Correct : {correct}
- Accuracy : {accuracy:.2%}
- Average Generation Time : {average_gen_time:.4f} seconds
- """
-
- # Print to console
- print(output_file_path)
- print(result_text)
-
- # Save to txt file
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path)
- f.write('\n')
- f.write(result_text)
- f.write('\n')
-
- if args.test_output_csv:
- try:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
- except Exception as e:
- print(f"Error during performance evaluation: {e}")
- pass
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_mil.py b/code/xtuner/tools/test_mil.py
deleted file mode 100644
index 7c592bd4ad6fceca89a61096295017ba9b1fbfda..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_mil.py
+++ /dev/null
@@ -1,507 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-from types import FunctionType
-import deepspeed
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-from sympy import im
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.registry import MAP_FUNC
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-import torch
-from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-from xtuner.utils import PROMPT_TEMPLATE
-from PIL import Image
-import pandas as pd
-import numpy as np
-from transformers import GenerationConfig, StoppingCriteriaList
-from model.architecture.transformer import ACMIL_GA_NoClassifier, ACMIL_MHA_NoClassifier
-
-import os
-from xtuner.model.llava_dim_reducer import TextGuidedVisualTokenAttentionReducer
-
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', default=None, help='checkpoint file')
- parser.add_argument('--test_slide_csv', default=None, help='test_slide_csv')
- parser.add_argument('--test_output_csv', default=None, help='test_output_csv')
- parser.add_argument('--tumor_type', default=None, help='test_output_csv')
- parser.add_argument(
- '--eval_output_path',
- default='slidechat_baseline_eval.txt',
- help='path to save evaluation results')
-
- parser.add_argument(
- '--torch-dtype',
- default='bf16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--work-dir',
- help='the directory to save the file containing evaluation metrics')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- parser.add_argument('--divprune_ratio', type = float, default = None, help='the ratio for divprune')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
- args = parser.parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in dict.items(cfg_dict):
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-def pairwise_l1_distance(matrix: torch.Tensor) -> torch.Tensor:
- """
- Compute the full pairwise L1 (Manhattan) distance matrix
- for an [N, D] tensor.
- """
- # torch.cdist with p=1 computes L1 distance
- return torch.cdist(matrix, matrix, p=1)
-
-def pairwise_cosine_similarity(matrix):
- norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
- cosine_similarity = torch.mm(norm_matrix, norm_matrix.t())
- return cosine_similarity
-
-
-def DivPrune(visual_feature_vectors, image_feature_length,
- cosine_matrix=None, threshold_ratio=0.1):
- threshold_terms = int(round(threshold_ratio * image_feature_length))
- if cosine_matrix is None:
- cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors))
-
- s = torch.empty(threshold_terms, dtype=torch.long, device=visual_feature_vectors.device)
- for i in range(threshold_terms):
- if i == 0:
- m2 = cosine_matrix
- else:
- m2 = torch.index_select(cosine_matrix, 0, torch.index_select(s, 0, torch.arange(0, i, device=cosine_matrix.device)))
-
- if i == 0:
- scores = torch.topk(m2, 2, dim=0, largest=False).values[1, :]
- else:
- scores = torch.min(m2, dim=0).values
-
- phrase_to_add_idx = torch.argmax(scores)
- s[i] = phrase_to_add_idx
- return s, cosine_matrix
-
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- # register FunctionType object in cfg to `MAP_FUNC` Registry and
- # change these FunctionType object to str
- register_function(cfg._cfg_dict)
-
- # work_dir is determined in this priority: CLI > segment in file > filename
- if args.work_dir is not None:
- # update configs according to CLI args if args.work_dir is not None
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- # use config filename as default work_dir if cfg.work_dir is None
- cfg.work_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.config))[0])
-
- # build the runner from config
- if 'runner_type' not in cfg:
- # build the default runner
- runner = Runner.from_cfg(cfg)
- else:
- # build customized runner from the registry
- # if 'runner_type' is set in the cfg
- runner = RUNNERS.build(cfg)
-
- model_kwargs = {
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- }
-
- state_dict = guess_load_checkpoint(args.checkpoint)
- # state_dict = torch.load(args.checkpoint, map_location='cpu')
- print(f'available keys in checkpoint: {state_dict.keys()}')
- runner.model.load_state_dict(state_dict, strict=False)
-
-
- ##############################qingq check loaded weights######################################
- missing_keys, unexpected_keys = runner.model.load_state_dict(state_dict, strict=False).missing_keys, \
- runner.model.load_state_dict(state_dict, strict=False).unexpected_keys
-
- print("✅ Missing keys (not in checkpoint):")
- for key in missing_keys:
- print(f" - {key}")
-
- print("\n⚠️ Unexpected keys (in checkpoint but not in model):")
- for key in unexpected_keys:
- print(f" - {key}")
- ##############################qingq check loaded weights######################################
-
-
-
- runner.model.eval()
- runner.logger.info(f'Load checkpoint from {args.checkpoint}')
-
-
- llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
- tokenizer = AutoTokenizer.from_pretrained(
- llm_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
-
- llm = runner.model.llm # dtype: float16
- llm.eval()
-
- LongNet_encoder = runner.model.LongNet_encoder.to(model_kwargs['torch_dtype']) # torch.bfloat16
- LongNet_encoder.cuda()
- LongNet_encoder.eval()
-
- projector = runner.model.projector.to(model_kwargs['torch_dtype'])
- projector.cuda()
- projector.eval()
-
- acmil = runner.model.acmil.to(model_kwargs['torch_dtype'])
- acmil.cuda()
- acmil.eval()
-
-
-
- # Check for and apply the visual token reducer
- visual_token_reducer = None
- if hasattr(runner.model, 'visual_token_reducer'):
- print("Visual token reducer found, applying it.")
- visual_token_reducer = runner.model.visual_token_reducer.to(model_kwargs['torch_dtype'])
- visual_token_reducer.cuda()
- visual_token_reducer.eval()
- # projector = torch.nn.Sequential(projector, visual_token_reducer)
- # print('Using visual token reducer')
-
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- df_test_case['Output'] = df_test_case.apply(lambda x: '', axis=1)
- columns = ['ID','Slide','Tumor','Broad Category','Narrow Category','Question','A','B','C','D','Answer','Output']
- df_test_output = pd.DataFrame(columns=columns)
-
- if args.test_output_csv:
- output_dir = os.path.dirname(args.test_output_csv)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- for i in range(df_test_case.shape[0]):
- # if i != 541:
- # continue
-
- print('*'*30)
- print('id: ', i, df_test_case.loc[i, 'Slide'])
- # only check the brca
- print('tumor type: ', args.tumor_type)
- if df_test_case.loc[i, 'Tumor'] != args.tumor_type: #'LUAD':
- continue
- print('tumor name: ', df_test_case.loc[i, 'Tumor'])
- tumor_name = df_test_case.loc[i, 'Tumor']
- case_name = df_test_case.loc[i, 'Slide']
-
- # test_image_file = "TCGA_patch_feat/" + df_test_case.loc[i, 'Tumor'] + "/" + case_name + ".csv"
- test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/' + str( tumor_name.lower() ) + '_224x224_b20_t15/pt_files/' + case_name + '.pt'
-
- # test_image_file = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A0CJ-01Z-00-DX2.csv'
- # test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/pt_files/TCGA-A7-A0CJ-01Z-00-DX2.pt'
-
-
- # test_image_file = '/data/qingq/PathVLM/baselines/github/SlideChat/dataset/WSI_feat_sample/TCGA-A7-A6VV-01Z-00-DX2.csv'
- # test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/brca_224x224_b20_t15/pt_files/TCGA-A7-A6VV-01Z-00-DX2.pt'
-
- # for some missing files, skip
- if not os.path.exists(test_image_file):
- with open("/data/qingq/PathVLM/baselines/github/SlideChat/outputs/missing_WSI_log.txt", "w") as f: # use "a" to append instead of overwrite
- f.write(test_image_file + "\n")
- f.close()
- continue
-
- if test_image_file.endswith('.csv'):
- image = pd.read_csv(test_image_file) # shape: [num_patches, 513]
- image = image.iloc[:, :512]
- total_rows = image.shape[0]
- sample_num = 38400
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_df = image.iloc[indices]
- image = sampled_df.iloc[:sample_num]
- image = image.to_numpy().reshape(1, image.shape[0], 512) # (1, N, 512)
- image = torch.from_numpy(image)
-
- # qingq modify, our feature format is .pt file
- elif test_image_file.endswith('.pt'):
- image = torch.load(test_image_file, map_location='cpu') # (N, 512)
- image = image.numpy()
- total_rows = image.shape[0]
- print('before sampling image shape', image.shape)
- sample_num = 10000 # 38400. original 38400 is out of memory for 45G
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_image = image[indices]
- image = sampled_image[:sample_num]
-
- # Reshape and convert to tensor: (1, N, 512)
- image = torch.from_numpy(image.reshape(1, -1, 512)) # final shape: (1, N, 512)
- print('final image shape', image.shape)
-
- else:
- image = Image.open(test_image_file).convert('RGB')
-
- image = image.cuda() # shape (1, patch_num, 512)
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
- question = df_test_case.loc[i, 'Question']
- options = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- options.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(options)
-
- sample_input = f"{question}\n{options_str}"
- print('Input: ', sample_input)
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(
- chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids).cuda()
-
-
- # image.dtype: float32
- # runner.model.dtype: float16
- # image = image.to(runner.model.projector.dtype)
- # image = runner.model.LongNet_encoder(src_tokens=None, token_embeddings=image.permute(1, 0, 2).to(runner.model.llm.dtype))["encoder_out"]
-
-
- # model_kwargs['torch_dtype']
- image = image.to(projector.dtype) # shape (1, patch_num, 512)
- image = LongNet_encoder(src_tokens=None, token_embeddings=image.permute(1, 0, 2))["encoder_out"] # output shape (patch_num, 1, 512)
-
- image = image.permute(1, 0, 2) # shape: [1, patch_num, 512]
-
- # pixel_values = runner.model.projector(image)
- pixel_values = projector(image) # shape: [1, patch_num, 3584]
-
- if args.divprune_ratio is not None and args.divprune_ratio > 0 and args.divprune_ratio < 1.0:
-
- print('Applying divprune with ratio:', args.divprune_ratio)
- # Apply divprune
- pruned_batch_features = []
- for visual_tokens in pixel_values: # Iterate over the batch dimension
- img_feature_len = visual_tokens.shape[0]
- selected_indices, _ = DivPrune(
- visual_tokens,
- img_feature_len,
- threshold_ratio=args.divprune_ratio
- )
- selected_indices = torch.sort(selected_indices).values
- pruned_features = visual_tokens[selected_indices]
- pruned_batch_features.append(pruned_features)
-
- # Stack the list of pruned tensors back into a single batch tensor
- pixel_values = torch.stack(pruned_batch_features, dim=0)
- print('After divprune, pixel_values shape: ', pixel_values.shape)
-
- if visual_token_reducer is not None:
- is_text_guided_reducer = isinstance(
- visual_token_reducer, TextGuidedVisualTokenAttentionReducer
- )
- print("Applying visual token reducer")
-
-
- if is_text_guided_reducer:
- # Get text embeddings and attention mask for the guided reducer
- # input_ids = data['input_ids']
- input_ids = input_ids.unsqueeze(0) # Add batch dimension
- text_attention_mask = None
-
- text_embeddings = llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
-
-
- pixel_values = visual_token_reducer(
- pixel_values, text_embeddings, text_attention_mask
- )
- else:
- # Input to reducer is now (B, T, D)
- pixel_values = visual_token_reducer(pixel_values)
-
-
- print('After visual token reducer, pixel_values shape: ', pixel_values.shape)
-
-
-
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=runner.model.llm,
- input_ids=input_ids.unsqueeze(0),
- pixel_values=pixel_values)
-
- max_new_tokens=500
- gen_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words=[]
- stop_words += prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- stop_criteria.append(
- StopWordStoppingCriteria(tokenizer, word))
-
- generate_output = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
-
- generation_output = tokenizer.decode(generate_output[0])
- if generation_output.endswith('<|im_end|>'):
- generation_output = generation_output[:-10]
-
- print('Output: ', generation_output)
-
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'A': df_test_case.loc[i, 'A'],
- 'B': df_test_case.loc[i, 'B'],
- 'C': df_test_case.loc[i, 'C'],
- 'D': df_test_case.loc[i, 'D'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output
- }
- df_test_output.loc[i] = add_row
- if args.test_output_csv:
- df_test_output.to_csv(args.test_output_csv)
-
- print('Test ok!')
-
-
- # check performance
-
-
- def slidechat_performance(output_file_path, eval_output_path):
-
- # Load the CSV
- df = pd.read_csv(output_file_path)
-
-
- # Clean ground-truth answers
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
-
- # Extract the letter before the period in 'Output' (e.g., 'A. Luminal A' → 'A')
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
-
- # Compute exact match
- df['correct'] = df['Answer_clean'] == df['Output_clean']
-
- # Calculate accuracy
- accuracy = df['correct'].mean()
-
- # Print summary
- total = len(df)
- correct = df['correct'].sum()
- print(f"Exact Match Accuracy: {accuracy:.2%} ({correct}/{total})")
-
-
- # Build the result string
- result_text = f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Correct : {correct}
- Accuracy : {accuracy:.2%}
- """
-
- # Print to console
- print(output_file_path)
- print(result_text)
-
- # Save to txt file
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path)
- f.write('\n')
- f.write(result_text)
- f.write('\n')
-
- if args.test_output_csv:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_random.py b/code/xtuner/tools/test_random.py
deleted file mode 100644
index 3cf3db752a51113e3bbce4bf58f1f25bcdcfd752..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_random.py
+++ /dev/null
@@ -1,593 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-from types import FunctionType
-import time
-import deepspeed
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint, LoadWoInit, prepare_inputs_labels_for_multimodal
-from xtuner.registry import MAP_FUNC
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria, PROMPT_TEMPLATE)
-
-import torch
-from transformers import (AutoTokenizer, GenerationConfig)
-from PIL import Image
-import pandas as pd
-import numpy as np
-from transformers import StoppingCriteriaList
-
-from xtuner.model.llava_dim_reducer import TextGuidedVisualTokenAttentionReducer
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto'
-)
-
-def _append_safe_csv(df: pd.DataFrame, path: str, append: bool = True):
- """Append df to CSV at path, writing header only if file doesn't exist or is empty."""
- if not path or df is None or df.empty:
- return
- out_dir = os.path.dirname(path)
- if out_dir and not os.path.exists(out_dir):
- os.makedirs(out_dir, exist_ok=True)
- exists = os.path.exists(path)
- empty = (not exists) or (os.path.getsize(path) == 0)
- mode = 'a' if append and exists and not empty else 'w'
- header = empty or not append
- df.to_csv(path, index=False, mode=mode, header=header)
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', default=None, help='checkpoint file')
- parser.add_argument('--test_slide_csv', default=None, help='test_slide_csv')
- parser.add_argument('--test_output_csv', default=None, help='test_output_csv')
- parser.add_argument('--tumor_type', default=None, help='restrict sampling to this tumor for fixed-question mode')
- parser.add_argument('--eval_output_path', default='slidechat_baseline_eval.txt', help='path to save evaluation results')
-
- parser.add_argument('--torch-dtype', default='bf16', choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under a specific `dtype`.')
-
- parser.add_argument('--work-dir', help='the directory to save the file containing evaluation metrics')
- parser.add_argument('--cfg-options', nargs='+', action=DictAction,
- help='override some settings in the used config')
-
- parser.add_argument('--divprune_ratio', type=float, default=None, help='the ratio for divprune')
-
- parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none', help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
-
- # NEW: fixed-question + random visual mode
- parser.add_argument('--fixed_question_id', type=int, default=None,
- help='If set, use this single question and sample random visuals K times.')
- parser.add_argument('--n_trials', type=int, default=20,
- help='Number of random-visual trials for the fixed question mode.')
- parser.add_argument('--random_seed', type=int, default=42,
- help='Random seed for reproducibility of random visual sampling.')
-
- # NEW: multi-run accumulation
- parser.add_argument('--append', action='store_true',
- help='Append to an existing CSV instead of overwriting.')
- parser.add_argument('--run_tag', default=None,
- help='Optional string to identify this run (e.g., "seed42_div0.2").')
- parser.add_argument('--max_new_tokens', type=int, default=1000,
- help='Maximum number of new tokens to generate in each run.')
-
- args = parser.parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in cfg_dict.items():
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-
-def pairwise_cosine_similarity(matrix):
- norm_matrix = matrix / matrix.norm(dim=1, keepdim=True)
- cosine_similarity = torch.mm(norm_matrix, norm_matrix.t())
- return cosine_similarity
-
-
-def DivPrune(visual_feature_vectors, image_feature_length,
- cosine_matrix=None, threshold_ratio=0.1):
- threshold_terms = int(round(threshold_ratio * image_feature_length))
- if threshold_terms <= 0:
- return torch.arange(0, 0, device=visual_feature_vectors.device), None
- if cosine_matrix is None:
- cosine_matrix = 1.0 - (pairwise_cosine_similarity(visual_feature_vectors))
-
- s = torch.empty(threshold_terms, dtype=torch.long, device=visual_feature_vectors.device)
- for i in range(threshold_terms):
- if i == 0:
- m2 = cosine_matrix
- else:
- m2 = torch.index_select(
- cosine_matrix, 0,
- torch.index_select(s, 0, torch.arange(0, i, device=cosine_matrix.device))
- )
- if i == 0:
- scores = torch.topk(m2, 2, dim=0, largest=False).values[1, :]
- else:
- scores = torch.min(m2, dim=0).values
- phrase_to_add_idx = torch.argmax(scores)
- s[i] = phrase_to_add_idx
- return s, cosine_matrix
-
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- register_function(cfg._cfg_dict)
-
- # work_dir
- if args.work_dir is not None:
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
-
- # runner
- if 'runner_type' not in cfg:
- runner = Runner.from_cfg(cfg)
- else:
- runner = RUNNERS.build(cfg)
-
- model_kwargs = {'trust_remote_code': True, 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]}
-
- # load weights
- state_dict = guess_load_checkpoint(args.checkpoint)
- print(f'available keys in checkpoint: {list(state_dict.keys())[:5]} ...')
- missing_unexp = runner.model.load_state_dict(state_dict, strict=False)
- print("✅ Missing keys:", missing_unexp.missing_keys)
- print("⚠️ Unexpected keys:", missing_unexp.unexpected_keys)
-
- runner.model.eval()
- runner.logger.info(f'Load checkpoint from {args.checkpoint}')
-
- llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
- tokenizer = AutoTokenizer.from_pretrained(
- llm_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True
- )
-
- llm = runner.model.llm
- llm.eval()
-
- LongNet_encoder = runner.model.LongNet_encoder.to(model_kwargs['torch_dtype'])
- LongNet_encoder.cuda().eval()
-
- projector = runner.model.projector.to(model_kwargs['torch_dtype'])
- projector.cuda().eval()
-
- # optional modules
- visual_token_reducer = None
- if hasattr(runner.model, 'visual_token_reducer'):
- print("Visual token reducer found.")
- visual_token_reducer = runner.model.visual_token_reducer.to(model_kwargs['torch_dtype'])
- visual_token_reducer.cuda().eval()
-
- mil = None
- if hasattr(runner.model, 'acmil'):
- print("ACMIL found.")
- mil = runner.model.acmil.to(model_kwargs['torch_dtype'])
- mil.cuda().eval()
-
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- # RNG
- rng = np.random.default_rng(args.random_seed)
-
- # run metadata
- run_meta = {
- 'RunTag': args.run_tag if args.run_tag is not None else '',
- 'Seed': args.random_seed,
- 'DivPruneRatio': args.divprune_ratio if args.divprune_ratio is not None else '',
- 'Checkpoint': os.path.basename(args.checkpoint) if args.checkpoint else '',
- 'Config': os.path.basename(args.config),
- }
-
- # helpers
- def slide_to_feat_path(tumor_name: str, slide_id: str) -> str:
- tumor_name_lc = str(tumor_name).lower()
- return f'/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/{tumor_name_lc}_224x224_b20_t15/pt_files/{slide_id}.pt'
-
- def build_prompt_from_row(row):
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
- question = row['Question']
- options = []
- for opt in ['A', 'B', 'C', 'D']:
- if opt in row and pd.notna(row[opt]):
- options.append(f"{opt}. {row[opt]}")
- options_str = '\n'.join(options) if options else ''
- sample_input = (f"{question}\n{options_str}").strip()
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
- chunks = inputs.split(DEFAULT_IMAGE_TOKEN)
- enc = []
- for idx, chunk in enumerate(chunks):
- cur = tokenizer.encode(chunk, add_special_tokens=(idx == 0))
- enc.append(cur)
- assert len(enc) == 2
- input_ids = []
- for idx, cur in enumerate(enc):
- input_ids.extend(cur)
- if idx != len(enc) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- return torch.tensor(input_ids).cuda()
-
- @torch.no_grad()
- def load_and_project_pixel_values(test_image_file: str):
- if not os.path.exists(test_image_file):
- return None
- if test_image_file.endswith('.pt'):
- image = torch.load(test_image_file, map_location='cpu') # (N, 512)
- if isinstance(image, torch.Tensor):
- image = image.numpy()
- total_rows = image.shape[0]
- sample_num = min(384000, total_rows) # keep modest to avoid OOM
- idxs = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- image = image[idxs]
- image = torch.from_numpy(image.reshape(1, -1, 512)) # (1, T, 512)
- else:
- raise ValueError("Expect .pt features; got: " + test_image_file)
-
- image = image.cuda().to(projector.dtype) # (1, T, 512)
-
- if mil is not None:
- _, image, _ = mil(image) # (1, T, 512)
- pixel_values = projector(image.unsqueeze(0)) # (1, T, D)
- else:
- enc = LongNet_encoder(src_tokens=None,
- token_embeddings=image.permute(1, 0, 2))["encoder_out"] # (T, 1, 512)
- enc = enc.permute(1, 0, 2) # (1, T, 512)
- pixel_values = projector(enc) # (1, T, D)
-
- if visual_token_reducer is not None:
- if isinstance(visual_token_reducer, TextGuidedVisualTokenAttentionReducer):
- return pixel_values, 'needs_text_guidance'
- else:
- pixel_values = visual_token_reducer(pixel_values)
- return pixel_values, None
-
- @torch.no_grad()
- def maybe_apply_divprune(pixel_values):
- if args.divprune_ratio is None or not (0.0 < args.divprune_ratio < 1.0):
- return pixel_values
- print('Applying divprune with ratio:', args.divprune_ratio)
- pruned_batch_features = []
- for visual_tokens in pixel_values: # (T, D) each (since batch=1, this loops once)
- img_feature_len = visual_tokens.shape[0]
- selected_indices, _ = DivPrune(
- visual_tokens, img_feature_len, threshold_ratio=args.divprune_ratio
- )
- selected_indices = torch.sort(selected_indices).values
- pruned = visual_tokens[selected_indices]
- pruned_batch_features.append(pruned)
- pixel_values = torch.stack(pruned_batch_features, dim=0)
- print('After divprune, pixel_values shape:', pixel_values.shape)
- return pixel_values
-
- @torch.no_grad()
- def generate_once(input_ids, pixel_values, needs_text_guidance=False, max_new_tokens=args.max_new_tokens):
- if needs_text_guidance and visual_token_reducer is not None:
- text_embeddings = llm.get_input_embeddings()(input_ids.clamp(min=0)).detach()
- pixel_values = visual_token_reducer(pixel_values, text_embeddings, None)
-
- pixel_values = maybe_apply_divprune(pixel_values)
-
- # FINAL image sequence length after all reductions (what the LLM actually sees)
- image_seq_len = int(pixel_values.shape[1]) # pixel_values shape: (1, T, D)
-
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=runner.model.llm,
- input_ids=input_ids.unsqueeze(0),
- pixel_values=pixel_values,
- )
-
- gen_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=True, # enables sampling per beam
- num_beams=5, # no beam search
- temperature=0.7, # controls randomness
- top_k=50,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words = PROMPT_TEMPLATE.qwen_chat.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList([StopWordStoppingCriteria(tokenizer, w) for w in stop_words])
-
- torch.cuda.synchronize()
- t0 = time.time()
- out = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria
- )
- torch.cuda.synchronize()
- dt = time.time() - t0
-
- text = tokenizer.decode(out[0])
- if text.endswith('<|im_end|>'):
- text = text[:-10]
- return text, dt, image_seq_len
-
- # ---------------- FIXED QUESTION + RANDOM VISUAL MODE ----------------
- if args.fixed_question_id is not None:
- qdf = df_test_case[df_test_case['ID'] == args.fixed_question_id]
- if qdf.empty:
- raise ValueError(f"ID {args.fixed_question_id} not found in {args.test_slide_csv}")
- qrow = qdf.iloc[0]
-
- pool_tumor = args.tumor_type if args.tumor_type else qrow['Tumor']
- pool_df = df_test_case[df_test_case['Tumor'] == pool_tumor].copy()
- if pool_df.empty:
- raise ValueError(f"No rows found for tumor '{pool_tumor}' to sample visuals.")
-
- fixed_input_ids = build_prompt_from_row(qrow)
-
- out_cols = ['Trial', 'RandomSlide', 'RandomTumor',
- 'FixedQuestionID', 'Question', 'A','B','C','D',
- 'Answer', 'Output', 'GenerationTime', 'ImageSeqLen',
- 'RunTag', 'Seed', 'DivPruneRatio', 'Checkpoint', 'Config']
- df_trials = pd.DataFrame(columns=out_cols)
-
- needs_guidance_flag = False
-
- if args.test_output_csv:
- out_dir = os.path.dirname(args.test_output_csv)
- if out_dir and not os.path.exists(out_dir):
- os.makedirs(out_dir, exist_ok=True)
-
- for t in range(args.n_trials):
- r_idx = int(rng.integers(0, len(pool_df)))
- rrow = pool_df.iloc[r_idx]
- rand_slide = rrow['Slide']
- rand_tumor = rrow['Tumor']
- feat_path = slide_to_feat_path(rand_tumor, rand_slide)
-
- loaded = load_and_project_pixel_values(feat_path)
- if loaded is None:
- print(f"[skip missing] {feat_path}")
- continue
-
- pixel_values, guidance = loaded
- if guidance == 'needs_text_guidance':
- needs_guidance_flag = True
-
- gen_text, dt, img_seq_len = generate_once(
- fixed_input_ids, pixel_values,
- needs_text_guidance=needs_guidance_flag,
- max_new_tokens=args.max_new_tokens
- )
-
- add = {
- 'Trial': t,
- 'RandomSlide': rand_slide,
- 'RandomTumor': rand_tumor,
- 'FixedQuestionID': int(qrow['ID']),
- 'Question': qrow['Question'],
- 'A': qrow.get('A', np.nan),
- 'B': qrow.get('B', np.nan),
- 'C': qrow.get('C', np.nan),
- 'D': qrow.get('D', np.nan),
- 'Answer': qrow.get('Answer', np.nan),
- 'Output': gen_text,
- 'GenerationTime': dt,
- 'ImageSeqLen': img_seq_len,
- 'RunTag': run_meta['RunTag'],
- 'Seed': run_meta['Seed'],
- 'DivPruneRatio': run_meta['DivPruneRatio'],
- 'Checkpoint': run_meta['Checkpoint'],
- 'Config': run_meta['Config'],
- }
- df_trials.loc[len(df_trials)] = add
-
- # append incrementally (survives crashes; accumulates across runs)
- if args.test_output_csv:
- _append_safe_csv(df_trials.iloc[[-1]], args.test_output_csv, append=args.append)
-
- print(f"[Trial {t:03d}] slide={rand_slide} tumor={rand_tumor}")
- print("Output:", gen_text)
- print(f"GenerationTime: {dt:.4f}s | ImageSeqLen: {img_seq_len}")
-
- print("Fixed-question random-visual evaluation done.")
- if args.test_output_csv:
- try:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
- except Exception as e:
- print(f"Eval error: {e}")
- return
- # ---------------- END FIXED QUESTION MODE ----------------
-
- # ------------ ORIGINAL PER-ROW EVAL (append to one CSV) ------------
- generation_times = []
-
- if args.test_output_csv:
- output_dir = os.path.dirname(args.test_output_csv)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir, exist_ok=True)
-
- for i in range(df_test_case.shape[0]):
- print('*'*30)
- print('id: ', i, df_test_case.loc[i, 'Slide'])
- print('tumor type (filter): ', args.tumor_type)
- if args.tumor_type and df_test_case.loc[i, 'Tumor'] != args.tumor_type:
- continue
- print('tumor name: ', df_test_case.loc[i, 'Tumor'])
- tumor_name = df_test_case.loc[i, 'Tumor']
- case_name = df_test_case.loc[i, 'Slide']
-
- test_image_file = slide_to_feat_path(tumor_name, case_name)
-
- if not os.path.exists(test_image_file):
- with open("./missing_WSI_log.txt", "a") as f:
- f.write(test_image_file + "\n")
- continue
-
- # load + project
- loaded = load_and_project_pixel_values(test_image_file)
- if loaded is None:
- continue
- pixel_values, guidance = loaded
-
- prompt_input_ids = build_prompt_from_row(df_test_case.loc[i])
-
- gen_text, dt, img_seq_len = generate_once(
- prompt_input_ids, pixel_values,
- needs_text_guidance=(guidance == 'needs_text_guidance'),
- max_new_tokens=args.max_new_tokens
- )
-
- print('Output: ', gen_text)
- print(f'Generation Time: {dt:.4f} seconds | ImageSeqLen: {img_seq_len}')
- try:
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': df_test_case.loc[i, 'Question'],
- 'A': df_test_case.loc[i, 'A'],
- 'B': df_test_case.loc[i, 'B'],
- 'C': df_test_case.loc[i, 'C'],
- 'D': df_test_case.loc[i, 'D'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': gen_text,
- 'GenerationTime': dt,
- 'ImageSeqLen': img_seq_len,
- 'RunTag': run_meta['RunTag'],
- 'Seed': run_meta['Seed'],
- 'DivPruneRatio': run_meta['DivPruneRatio'],
- 'Checkpoint': run_meta['Checkpoint'],
- 'Config': run_meta['Config'],
- }
- except Exception:
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': df_test_case.loc[i, 'Question'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': gen_text,
- 'GenerationTime': dt,
- 'ImageSeqLen': img_seq_len,
- 'RunTag': run_meta['RunTag'],
- 'Seed': run_meta['Seed'],
- 'DivPruneRatio': run_meta['DivPruneRatio'],
- 'Checkpoint': run_meta['Checkpoint'],
- 'Config': run_meta['Config'],
- }
-
- # append this row
- if args.test_output_csv:
- _append_safe_csv(pd.DataFrame([add_row]), args.test_output_csv, append=args.append)
- generation_times.append(dt)
-
- print('Test ok!')
-
- if generation_times:
- average_time = np.mean(generation_times)
- print(f"\nAverage Generation Time over {len(generation_times)} samples: {average_time:.4f} seconds")
-
- if args.test_output_csv:
- try:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
- except Exception as e:
- print(f"Error during performance evaluation: {e}")
- pass
-
-
-def slidechat_performance(output_file_path, eval_output_path):
- df = pd.read_csv(output_file_path)
- # tolerate missing cols in fixed-question mode
- if 'Answer' in df.columns and 'Output' in df.columns:
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
- df['correct'] = df['Answer_clean'] == df['Output_clean']
- accuracy = df['correct'].mean()
- else:
- accuracy = float('nan')
- df['correct'] = False
-
- average_gen_time = df['GenerationTime'].mean() if 'GenerationTime' in df.columns else float('nan')
- total = len(df)
- correct = df['correct'].sum() if 'correct' in df.columns else 0
-
- # Image sequence statistics (optional)
- if 'ImageSeqLen' in df.columns:
- img_len_avg = df['ImageSeqLen'].mean()
- img_len_min = df['ImageSeqLen'].min()
- img_len_max = df['ImageSeqLen'].max()
- else:
- img_len_avg = img_len_min = img_len_max = float('nan')
-
- print(f"Exact Match Accuracy: {accuracy:.2%}" if accuracy == accuracy else "Exact Match Accuracy: N/A")
- print(f"({correct}/{total})" if accuracy == accuracy else f"Total Samples: {total}")
- print(f"Average Generation Time: {average_gen_time:.4f} seconds" if average_gen_time == average_gen_time else "Average Generation Time: N/A")
- if img_len_avg == img_len_avg: # not NaN
- print(f"Image Seq Len (avg/min/max): {img_len_avg:.1f}/{img_len_min}/{img_len_max}")
-
- result_text = f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Correct : {correct}
- Accuracy : {accuracy:.2%}""" if accuracy == accuracy else f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Accuracy : N/A"""
- result_text += f"""
- Average Generation Time : {average_gen_time:.4f} seconds
- Image Seq Len (avg/min/max) : {img_len_avg:.1f}/{img_len_min if img_len_avg == img_len_avg else 'N/A'}/{img_len_max if img_len_avg == img_len_avg else 'N/A'}
- """
- print(output_file_path)
- print(result_text)
-
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path)
- f.write('\n')
- f.write(result_text)
- f.write('\n')
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/test_token_compressor.py b/code/xtuner/tools/test_token_compressor.py
deleted file mode 100644
index d7e4a7d25256a8c2ebb9fb092c17a272517c0021..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/test_token_compressor.py
+++ /dev/null
@@ -1,507 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import os
-import os.path as osp
-from types import FunctionType
-import deepspeed
-import time # <-- ADDED
-
-from mmengine.config import Config, DictAction
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-from sympy import im
-
-from xtuner.configs import cfgs_name_path
-from xtuner.model.utils import guess_load_checkpoint
-from xtuner.registry import MAP_FUNC
-from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
- StopWordStoppingCriteria)
-import torch
-from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
-from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig, CLIPImageProcessor,
- CLIPVisionModel, GenerationConfig)
-# Import for creating the attention mask for pre-fusion layers
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from xtuner.utils import PROMPT_TEMPLATE
-from PIL import Image
-import pandas as pd
-import numpy as np
-from transformers import GenerationConfig, StoppingCriteriaList
-
-import os
-
-TORCH_DTYPE_MAP = dict(
- fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Test model')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--checkpoint', default=None, help='checkpoint file')
- parser.add_argument('--test_slide_csv', default=None, help='test_slide_csv')
- parser.add_argument('--test_output_csv', default=None, help='test_output_csv')
- parser.add_argument('--tumor_type', default=None, help='test_output_csv')
- parser.add_argument(
- '--eval_output_path',
- default='slidechat_baseline_eval.txt',
- help='path to save evaluation results')
-
- parser.add_argument(
- '--torch-dtype',
- default='bf16',
- choices=TORCH_DTYPE_MAP.keys(),
- help='Override the default `torch.dtype` and load the model under '
- 'a specific `dtype`.')
- parser.add_argument(
- '--work-dir',
- help='the directory to save the file containing evaluation metrics')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
- args = parser.parse_args()
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in dict.items(cfg_dict):
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-def main():
- torch.cuda.set_device(0)
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- cfg.launcher = args.launcher
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- # register FunctionType object in cfg to `MAP_FUNC` Registry and
- # change these FunctionType object to str
- register_function(cfg._cfg_dict)
-
- # work_dir is determined in this priority: CLI > segment in file > filename
- if args.work_dir is not None:
- # update configs according to CLI args if args.work_dir is not None
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- # use config filename as default work_dir if cfg.work_dir is None
- cfg.work_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.config))[0])
-
- # build the runner from config
- if 'runner_type' not in cfg:
- # build the default runner
- runner = Runner.from_cfg(cfg)
- else:
- # build customized runner from the registry
- # if 'runner_type' is set in the cfg
- runner = RUNNERS.build(cfg)
-
- model_kwargs = {
- 'trust_remote_code': True,
- 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
- }
-
- state_dict = guess_load_checkpoint(args.checkpoint)
- # state_dict = torch.load(args.checkpoint, map_location='cpu')
- print(f'available keys in checkpoint: {state_dict.keys()}')
- runner.model.load_state_dict(state_dict, strict=False)
-
-
- ##############################qingq check loaded weights######################################
- missing_keys, unexpected_keys = runner.model.load_state_dict(state_dict, strict=False).missing_keys, \
- runner.model.load_state_dict(state_dict, strict=False).unexpected_keys
-
- print("✅ Missing keys (not in checkpoint):")
- for key in missing_keys:
- print(f" - {key}")
-
- print("\n⚠️ Unexpected keys (in checkpoint but not in model):")
- for key in unexpected_keys:
- print(f" - {key}")
- ##############################qingq check loaded weights######################################
-
-
-
- runner.model.eval()
- runner.logger.info(f'Load checkpoint from {args.checkpoint}')
-
-
- llm_name_or_path = 'Qwen/Qwen2.5-7B-Instruct'
- tokenizer = AutoTokenizer.from_pretrained(
- llm_name_or_path,
- trust_remote_code=True,
- encode_special_tokens=True)
-
- llm = runner.model.llm # dtype: float16
- llm.eval()
-
- LongNet_encoder = runner.model.LongNet_encoder.to(model_kwargs['torch_dtype']) # torch.bfloat16
- LongNet_encoder.cuda()
- LongNet_encoder.eval()
-
- projector = runner.model.projector.to(model_kwargs['torch_dtype'])
- projector.cuda()
- projector.eval()
-
- # Load compressor
- print("Loading compressor...")
- compressor = runner.model.compressor.to(model_kwargs['torch_dtype'])
- compressor.cuda()
- compressor.eval()
-
- # Load prefusion layers if they exist in the model
- prefusion_layers = None
- prefusion_layer_num = 0
- if hasattr(runner.model, 'prefusion_layers') and runner.model.prefusion_layer_num > 0:
- prefusion_layer_num = runner.model.prefusion_layer_num
- print(f"Found {prefusion_layer_num} prefusion layers. Loading them.")
- prefusion_layers = runner.model.prefusion_layers.to(model_kwargs['torch_dtype'])
- prefusion_layers.cuda()
- prefusion_layers.eval()
- else:
- print("No prefusion layers found in the model.")
-
-
- df_test_case = pd.read_csv(args.test_slide_csv)
-
- df_test_case['Output'] = df_test_case.apply(lambda x: '', axis=1)
- # Add a new column for generation time
- columns = ['ID','Slide','Tumor','Broad Category','Narrow Category','Question','A','B','C','D','Answer','Output', 'Generation Time (s)'] # <-- MODIFIED
- df_test_output = pd.DataFrame(columns=columns)
- generation_times = []
-
- if args.test_output_csv:
- output_dir = os.path.dirname(args.test_output_csv)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
- for i in range(df_test_case.shape[0]):
- # if i != 541:
- # continue
-
- print('*'*30)
- print('id: ', i, df_test_case.loc[i, 'Slide'])
- # only check the brca
- print('tumor type: ', args.tumor_type)
- if df_test_case.loc[i, 'Tumor'] != args.tumor_type: #'LUAD':
- continue
- print('tumor name: ', df_test_case.loc[i, 'Tumor'])
- tumor_name = df_test_case.loc[i, 'Tumor']
- case_name = df_test_case.loc[i, 'Slide']
-
- # test_image_file = "TCGA_patch_feat/" + df_test_case.loc[i, 'Tumor'] + "/" + case_name + ".csv"
- test_image_file = '/data/qingq/PathVLM/dataset/TCGA_features/conch_v1/' + str( tumor_name.lower() ) + '_224x224_b20_t15/pt_files/' + case_name + '.pt'
-
- # for some missing files, skip
- if not os.path.exists(test_image_file):
- with open("/data/qingq/PathVLM/baselines/github/SlideChat/outputs/missing_WSI_log.txt", "w") as f: # use "a" to append instead of overwrite
- f.write(test_image_file + "\n")
- f.close()
- continue
-
- if test_image_file.endswith('.csv'):
- image = pd.read_csv(test_image_file) # shape: [num_patches, 513]
- image = image.iloc[:, :512]
- total_rows = image.shape[0]
- sample_num = 38400
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_df = image.iloc[indices]
- image = sampled_df.iloc[:sample_num]
- image = image.to_numpy().reshape(1, image.shape[0], 512) # (1, N, 512)
- image = torch.from_numpy(image)
-
- # qingq modify, our feature format is .pt file
- elif test_image_file.endswith('.pt'):
- image = torch.load(test_image_file, map_location='cpu') # (N, 512)
- image = image.numpy()
- total_rows = image.shape[0]
- print('before sampling image shape', image.shape)
- sample_num = 10000 # 38400. original 38400 is out of memory for 45G
- if total_rows >= sample_num:
- indices = np.linspace(0, total_rows - 1, sample_num, dtype=int)
- sampled_image = image[indices]
- image = sampled_image[:sample_num]
-
- # Reshape and convert to tensor: (1, N, 512)
- image = torch.from_numpy(image.reshape(1, -1, 512)) # final shape: (1, N, 512)
- print('final image shape', image.shape)
-
- else:
- image = Image.open(test_image_file).convert('RGB')
-
- image = image.cuda() # shape (1, patch_num, 512)
- prompt_template = PROMPT_TEMPLATE.qwen_chat
- SYSTEM = ''
- question = df_test_case.loc[i, 'Question']
- options = []
- for opt in ['A', 'B', 'C', 'D']:
- if pd.notna(df_test_case.loc[i, opt]):
- options.append(f"{opt}. {df_test_case.loc[i, opt]}")
- options_str = '\n'.join(options)
-
- sample_input = f"{question}\n{options_str}"
- print('Input: ', sample_input)
-
- instruction = prompt_template.get('INSTRUCTION', '{input}')
- sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
- inputs = (SYSTEM + instruction).format(input=sample_input, round=1, **runner.cfg)
- chunk_encode = []
- for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
- if idx == 0:
- cur_encode = tokenizer.encode(chunk)
- else:
- cur_encode = tokenizer.encode(
- chunk, add_special_tokens=False)
- chunk_encode.append(cur_encode)
- assert len(chunk_encode) == 2
- input_ids = []
- for idx, cur_chunk_encode in enumerate(chunk_encode):
- input_ids.extend(cur_chunk_encode)
- if idx != len(chunk_encode) - 1:
- input_ids.append(IMAGE_TOKEN_INDEX)
- input_ids = torch.tensor(input_ids).cuda()
-
- image = image.to(projector.dtype)
- # 1. Pass through LongNet encoder
- feat_to_proj = image.permute(1, 0, 2)
- long_net_output = LongNet_encoder(src_tokens=None, token_embeddings=feat_to_proj)["encoder_out"]
- feat_to_proj = long_net_output.permute(1, 0, 2)
-
- # 2. Apply compressor and then projector
- compressed_features, _ = compressor(feat_to_proj)
- pixel_values = projector(compressed_features)
- print(f'Compressed visual features shape: {pixel_values.shape}')
-
- mm_inputs_kwargs = {}
-
- # 3. Apply pre-fusion layers if they exist
- if prefusion_layer_num > 0 and prefusion_layers is not None:
- print('Applying pre-fusion layers...')
- # Project the original (uncompressed) features for global context
- projected_global_image_features = projector(feat_to_proj)
- print(f'Global visual features shape: {projected_global_image_features.shape}')
-
- # Get text embeddings (ensure batch dimension is present)
- batched_input_ids = input_ids.unsqueeze(0)
- text_embeddings = llm.get_input_embeddings()(batched_input_ids.clamp(min=0))
-
- # Create padding mask for text
- padding_mask = (batched_input_ids <= 0)
-
- # Concatenate all features: [global_visual, compressed_visual, text]
- x = torch.cat([projected_global_image_features, pixel_values, text_embeddings], dim=1)
- print(f'Concatenated input shape for pre-fusion: {x.shape}')
-
- # Create the full attention mask for the combined sequence
- mask=torch.cat((torch.zeros((padding_mask.size(0),projected_global_image_features.size(1)+pixel_values.size(1)),device=padding_mask.device).bool(),padding_mask),dim=1)
-
- # Prepare attention mask for prefusion layers
- if getattr(llm, "_use_flash_attention_2", False) or \
- getattr(llm.config, "_attn_implementation", "") == "flash_attention_2":
- attention_mask = (~mask).int()
- else:
- attention_mask =_prepare_4d_causal_attention_mask(~mask, (x.size(0), x.size(1)), x, 0)
-
- position_ids = (~mask).int().long().cumsum(-1) - 1
- position_ids.masked_fill_((~mask).int() == 0, 1)
-
- # Apply pre-fusion layers sequentially
- for layer in prefusion_layers:
- x = layer(
- x,
- attention_mask=attention_mask,
- position_ids=position_ids,
- use_cache=False,
- )[0]
-
- # Split the features back after the pre-fusion block
- text_seq_len = batched_input_ids.size(1)
- compressed_seq_len = pixel_values.size(1)
-
- fusion_text_features = x[:, -text_seq_len:, :]
- pixel_values = x[:, -(text_seq_len + compressed_seq_len):-text_seq_len, :]
-
- # Re-apply padding to text features (replicating model's logic)
- fusion_text_features = fusion_text_features * (~padding_mask).unsqueeze(-1).float() + text_embeddings * padding_mask.unsqueeze(-1).float()
-
- # Set the 'text_features' kwarg for the final input preparation
- mm_inputs_kwargs['text_features'] = fusion_text_features
-
- # mm_inputs_kwargs['pixel_values'] = pixel_values
- print(f'Shape of visual features after pre-fusion: {pixel_values.shape}')
- print(f'Shape of text features after pre-fusion: {fusion_text_features.shape}')
-
- # Final preparation of inputs for the LLM
- mm_inputs = prepare_inputs_labels_for_multimodal(
- llm=runner.model.llm,
- input_ids=input_ids.unsqueeze(0),
- pixel_values=pixel_values,
- **mm_inputs_kwargs) # Pass text_features if generated by pre-fusion
-
- max_new_tokens=500
- gen_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=False,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
- if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
- )
- stop_words=[]
- stop_words += prompt_template.get('STOP_WORDS', [])
- stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- stop_criteria.append(
- StopWordStoppingCriteria(tokenizer, word))
-
- # Start timer before generation
- start_time = time.time() # <-- ADDED
-
- generate_output = llm.generate(
- **mm_inputs,
- generation_config=gen_config,
- streamer=None,
- bos_token_id=tokenizer.bos_token_id,
- stopping_criteria=stop_criteria)
-
- # End timer and calculate duration
- torch.cuda.synchronize() # Ensure all operations are complete before measuring time
- end_time = time.time() # <-- ADDED
- generation_time = end_time - start_time # <-- ADDED
- generation_times.append(generation_time) # <-- ADDED
-
- generation_output = tokenizer.decode(generate_output[0])
- if generation_output.endswith('<|im_end|>'):
- generation_output = generation_output[:-10]
-
- print('Output: ', generation_output)
- print(f'Generation Time: {generation_time:.2f} seconds') # <-- ADDED
-
- add_row = {
- 'ID': df_test_case.loc[i, 'ID'],
- 'Slide': df_test_case.loc[i, 'Slide'],
- 'Tumor': df_test_case.loc[i, 'Tumor'],
- 'Broad Category': df_test_case.loc[i, 'Broad Category'],
- 'Narrow Category': df_test_case.loc[i, 'Narrow Category'],
- 'Question': question,
- 'A': df_test_case.loc[i, 'A'],
- 'B': df_test_case.loc[i, 'B'],
- 'C': df_test_case.loc[i, 'C'],
- 'D': df_test_case.loc[i, 'D'],
- 'Answer': df_test_case.loc[i, 'Answer'],
- 'Output': generation_output,
- 'Generation Time (s)': generation_time # <-- ADDED
- }
- df_test_output.loc[i] = add_row
- if args.test_output_csv:
- df_test_output.to_csv(args.test_output_csv, index=False) # <-- MODIFIED to prevent writing index column
-
- torch.cuda.empty_cache()
- if prefusion_layer_num > 0 and prefusion_layers is not None:
- del x # Clear the concatenated input tensor
- del attention_mask # Clear the attention mask
- del position_ids
-
-
- print('Test ok!')
- if generation_times:
- average_time = np.mean(generation_times)
- print(f"\nAverage Generation Time over {len(generation_times)} samples: {average_time:.4f} seconds")
-
-
- # check performance
-
-
- def slidechat_performance(output_file_path, eval_output_path):
-
- # Load the CSV
- df = pd.read_csv(output_file_path)
-
-
- # Clean ground-truth answers
- df['Answer_clean'] = df['Answer'].astype(str).str.strip().str.upper()
-
- # Extract the letter before the period in 'Output' (e.g., 'A. Luminal A' → 'A')
- df['Output_clean'] = df['Output'].astype(str).str.strip().str.extract(r'^([A-D])', expand=False).str.upper()
-
- # Compute exact match
- df['correct'] = df['Answer_clean'] == df['Output_clean']
-
- # Calculate accuracy
- accuracy = df['correct'].mean()
-
- # Calculate timing statistics <-- ADDED
- if 'Generation Time (s)' in df.columns:
- avg_time = df['Generation Time (s)'].mean()
- total_time = df['Generation Time (s)'].sum()
- time_stats = f"""Average Gen. Time: {avg_time:.4f} seconds
-Total Gen. Time : {total_time:.2f} seconds"""
- else:
- time_stats = "Generation time not available."
-
-
- # Print summary
- total = len(df)
- correct = df['correct'].sum()
- print(f"Exact Match Accuracy: {accuracy:.2%} ({correct}/{total})")
-
-
- # Build the result string <-- MODIFIED to include time
- result_text = f"""Evaluation Summary:
- ---------------------
- Total Samples : {total}
- Correct : {correct}
- Accuracy : {accuracy:.2%}
- {time_stats}
- """
-
- # Print to console
- print(output_file_path)
- print(result_text)
-
- # Save to txt file
- with open(eval_output_path, 'a+') as f:
- f.write(output_file_path)
- f.write('\n')
- f.write(result_text)
- f.write('\n')
-
- if args.test_output_csv:
- slidechat_performance(args.test_output_csv, args.eval_output_path)
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/tokenize_ftdp_datasets.py b/code/xtuner/tools/tokenize_ftdp_datasets.py
deleted file mode 100644
index 9327a91fef9f79c48d4c3e933e7f039e0a11f191..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/tokenize_ftdp_datasets.py
+++ /dev/null
@@ -1,433 +0,0 @@
-import argparse
-import json
-import os
-import os.path as osp
-from functools import partial
-from pathlib import Path
-from typing import Dict, List
-
-import numpy as np
-from mmengine import list_dir_or_file, track_progress_rich
-from transformers import AutoTokenizer
-
-SEPCIAL_TOKENS = [
- '<|plugin|>', '<|interpreter|>', '<|action_end|>', '<|action_start|>',
- '<|im_end|>', '<|im_start|>'
-]
-
-CHATML_LLAMAV13_32K_TOKEN_CFG = dict(
- role_cfg=dict(
- system=dict(
- begin=dict(
- with_name='<|im_start|>system name={name}\n',
- without_name='<|im_start|>system\n',
- name={
- 'interpreter': '<|interpreter|>',
- 'plugin': '<|plugin|>',
- }),
- end='<|im_end|>\n',
- loss=dict(
- meta=False,
- icl=False,
- current=False,
- prefix=False,
- )),
- user=dict(
- begin=dict(
- with_name='<|im_start|>user name={name}\n',
- without_name='<|im_start|>user\n',
- ),
- end='<|im_end|>\n',
- loss=dict(
- icl=False,
- current=False,
- prefix=False,
- )),
- assistant=dict(
- begin=dict(
- with_name='<|im_start|>assistant name={name}\n',
- without_name='<|im_start|>assistant\n',
- name={
- 'interpreter': '<|interpreter|>',
- 'plugin': '<|plugin|>',
- }),
- end='<|im_end|>\n',
- loss=dict(
- icl=True,
- current=True,
- prefix=False,
- end=True,
- )),
- environment=dict(
- begin=dict(
- with_name='<|im_start|>environment name={name}\n',
- without_name='<|im_start|>environment\n',
- name={
- 'interpreter': '<|interpreter|>',
- 'plugin': '<|plugin|>',
- }),
- end='<|im_end|>\n',
- loss=dict(
- icl=False,
- current=False,
- prefix=False,
- )),
- tool=dict(
- begin=dict(
- with_name='<|action_start|>{name}\n',
- name={
- 'interpreter': '<|interpreter|>',
- 'plugin': '<|plugin|>',
- }),
- end='<|action_end|>\n',
- belong='assistant',
- ),
- thought=dict(
- begin=dict(without_name=''),
- end='',
- belong='assistant',
- ),
- ),
- max_len=32 * 1024,
-)
-
-
-def chatml_format(
- processed_data,
- tokenizer,
- role_cfg,
- max_len=2048,
- encode_json=True,
-):
- """
- ```python
- dict(
- role='',
- content='',
- name='', -> Begin 扩增
- type='',
- )
- ```
- ```python
- dict(
- system=dict(
- begin=dict(
- with_name='system name={name}\n',
- without_name='system\n',
- name={
- 'interpreter': '',
- 'plugin': '',
- }),
- end='\n',
- loss=dict(
- meta=False,
- icl=False,
- current=False,
- prefix=False,
- )),
- user=dict(
- begin=dict(
- with_name='user name={name}\n',
- without_name='user\n',
- ),
- end='\n',
- loss=dict(
- icl=False,
- current=False,
- prefix=False,
- )),
- assistant=dict(
- begin=dict(
- with_name='assistant name={name}\n',
- without_name='assistant\n',
- name={
- 'interpreter': '',
- 'plugin': '',
- }),
- end='\n',
- loss=dict(
- icl=True,
- current=True,
- prefix=False,
- end=True,
- )),
- environment=dict(
- begin=dict(
- with_name='environment name={name}\n',
- without_name='environment\n',
- name={
- 'interpreter': '',
- 'plugin': '',
- }),
- end='\n',
- loss=dict(
- icl=False,
- current=False,
- prefix=False,
- )),
- tool=dict(
- begin=dict(
- with_name='{name}\n',
- name={
- 'interpreter': '',
- 'plugin': '',
- }),
- end='\n',
- belong='assistant',
- ),
- thought=dict(
- begin='',
- end='',
- belong='assistant',
- ),
- ```
- """
-
- def format_begin(role_cfg, message):
- name = message.get('name', None)
- if name is not None:
- begin = role_cfg['begin'].get('with_name', '')
- if name in role_cfg['begin'].get('name', {}):
- begin = begin.format(name=role_cfg['begin']['name'][name])
- else:
- begin = begin.format(name=name)
- else:
- begin = role_cfg['begin'].get('without_name', '')
- return begin
-
- def format_sub_role(messages: List[Dict], roles_cfg) -> List[Dict]:
- new_message = list()
- for message in messages:
- if message['role'] in [
- 'assistant', 'user', 'system', 'environment'
- ]:
- new_message.append(message)
- continue
- role_cfg = roles_cfg[message['role']]
- begin = format_begin(role_cfg, message)
- new_content = begin + message['content'] + role_cfg['end']
- if role_cfg.get('fallback_role'):
- new_message.append(
- dict(role=role_cfg['fallback_role'], content=new_content))
- elif role_cfg.get('belong'):
- if new_message[-1]['role'] != role_cfg.get('belong'):
- new_message.append(
- dict(role=role_cfg.get('belong'), content=new_content))
- else:
- new_message[-1]['content'] += new_content
- else:
- new_message.append(
- dict(role=message['role'], content=new_content))
-
- return new_message
-
- token_ids = []
- _processed_data = format_sub_role(processed_data, role_cfg)
-
- for dialog_item in _processed_data:
- role = dialog_item['role']
- content = dialog_item['content']
- # TODO: is strip necessary? or use lstrip? 避免开始有\n\n的情况
- # content = content.lstrip()
- begin = format_begin(role_cfg[role], dialog_item)
- end = role_cfg[role]['end']
- begin_token = tokenizer.encode(begin, add_special_tokens=False)
- if not role_cfg[role]['loss'].get('beigin', False):
- begin_token = [-token_id for token_id in begin_token]
- end_token = tokenizer.encode(
- role_cfg[role]['end'], add_special_tokens=False)
- # breakpoint()
- if not role_cfg[role]['loss'].get('end', False):
- end_token = [-token_id for token_id in end_token]
-
- content_token = tokenizer.encode(
- begin + content + end, add_special_tokens=False)
- content_token = content_token[len(begin_token):-len(end_token)]
-
- if dialog_item.get('loss', True):
- loss_cfg = role_cfg[role]['loss']
- else:
- loss_cfg = dict(icl=False, current=False, meta=False)
- if not loss_cfg[dialog_item.get('type', 'current')]:
- content_token = [-token_id for token_id in content_token]
-
- if begin == '':
- tokens = content_token
- else:
- tokens = begin_token + content_token
- if end != '':
- tokens = tokens + end_token
-
- token_ids += tokens
-
- token_ids = [tokenizer.bos_token_id] + token_ids
- token_ids = token_ids[:max_len]
- if encode_json:
- line = str.encode(json.dumps({'tokens': token_ids}) + '\n')
- return line, len(token_ids)
- return token_ids, len(token_ids)
-
-
-def write_bin_meta_bin(path, dataset_name, filename, samples):
- train_path = osp.join(path, f'train/cn/{dataset_name}')
- valid_path = osp.join(path, f'valid/cn/{dataset_name}')
- train_dir = Path(train_path)
- valid_dir = Path(valid_path)
- train_dir.mkdir(exist_ok=True, parents=True)
- valid_dir.mkdir(exist_ok=True, parents=True)
- train_f = open(train_dir.joinpath(f'{filename}.bin'), 'wb')
- valid_f_path = valid_dir.joinpath(f'{filename}.bin')
- valid_f = open(valid_f_path, 'wb')
- print(train_dir)
- print(valid_dir)
- train_tokens = 0
- valid_tokens = 0
- last_train_position = 0
- last_valid_position = 0
- train_samples = 0
- valid_samples = 0
- train_meta = []
- valid_meta = []
- for line, token_num in samples:
- train_tokens += token_num
- train_f.write(line)
- train_meta.append((last_train_position, token_num))
- last_train_position += len(line)
- train_samples += 1
- if (train_samples) % 100 == 0: # ?
- valid_tokens += token_num
- valid_f.write(line)
- valid_meta.append((last_valid_position, token_num))
- last_valid_position += len(line)
- valid_samples += 1
- train_f.close()
- valid_f.close()
- np.save(open(train_dir.joinpath(f'{filename}.bin.meta'), 'wb'), train_meta)
-
- # remove the length of `valid_samples` is less than 500
- # 500 is a magic number, you can change it to any number you want
- # the number must bigger the DP.
- if valid_samples > 500:
- np.save(
- open(valid_dir.joinpath(f'{filename}.bin.meta'), 'wb'), valid_meta)
- else:
- print(f'{valid_f_path} is removed because the number of',
- f'`valid_samples`({valid_samples}) is less than 500')
- os.remove(valid_f_path)
- return train_tokens, valid_tokens, train_samples, valid_samples
-
-
-def tokenize_and_save(tokenizer, processed_dir, tokenized_dir):
- tokenized_save_dir = osp.join(tokenized_dir, 'chatml_llamav13_32k')
- data_dir = processed_dir
- all_train_tokens = 0
- all_valid_tokens = 0
- all_train_samples = 0
- all_valid_samples = 0
-
- for filename in list_dir_or_file(data_dir, recursive=True, list_dir=False):
- file_path = os.path.join(data_dir, filename)
- if '/processed/' not in file_path:
- continue
- assert '.jsonl' in filename
-
- # dataset name such as char_x10_chat_format
- dataset_name = filename.split(os.sep)[0]
-
- # Hardcode here to skip tokenizing the file if it already exists
- # (Refactor the `write_bin_meta_bin`!).
- train_f = osp.join(tokenized_save_dir, 'train', 'cn', dataset_name,
- f'{osp.splitext(osp.basename(filename))[0]}.bin')
- if osp.isfile(train_f):
- print(f'{train_f} already exists, skip it')
- continue
-
- tokenize_fun = partial(
- chatml_format,
- tokenizer=tokenizer,
- **CHATML_LLAMAV13_32K_TOKEN_CFG)
- samples = []
- with open(file_path) as f:
- dataset = f.readlines()
- task_num = len(dataset)
- dataset = map(lambda x: (json.loads(x), ), dataset)
-
- for sample in track_progress_rich(
- tokenize_fun,
- dataset,
- nproc=32,
- task_num=task_num,
- chunksize=32,
- description=f'{os.path.basename(file_path)}...'):
- samples.append(sample)
-
- train_tokens, valid_tokens, train_samples, valid_samples = write_bin_meta_bin( # noqa E501
- path=tokenized_save_dir,
- dataset_name=dataset_name,
- samples=samples,
- filename=osp.splitext(osp.basename(filename))[0])
- if train_tokens is None:
- print(f'{osp.splitext(osp.basename(filename))[0]} already '
- 'exists, skip it')
- continue
-
- print(f'train_tokens {train_tokens}', flush=True)
- print(f'train_samples {train_samples}')
- print(f'valid tokens {valid_tokens}')
- print(f'valid_samples {valid_samples}')
- all_train_tokens += train_tokens
- all_valid_tokens += valid_tokens
- all_train_samples += train_samples
- all_valid_samples += valid_samples
-
- print(f'all train tokens {all_train_tokens}')
- print(f'all train samples {all_train_samples}')
- print(f'all valid tokens {all_valid_tokens}')
- print(f'all valid samples {all_valid_samples}')
-
-
-def tokenizer_add_special_tokens(tokenizer):
- print(f'Before adding special tokens, Vocabulary Size: {len(tokenizer)}')
- for special_token in SEPCIAL_TOKENS:
- if special_token not in tokenizer.get_vocab():
- tokenizer.add_tokens([special_token], special_tokens=True)
- print(f'After adding special tokens, Vocabulary Size: {len(tokenizer)}')
-
-
-def save_new_tokenizer(tokenizer, save_dir):
- tokenizer.save_pretrained(save_dir)
- print(f'save new tokenizer to {save_dir}')
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--processed-dir', help='The folder to save untokenized data.')
- parser.add_argument(
- '--tokenized-dir', help='The folder to save tokenized data.')
- parser.add_argument(
- '--tokenizer-path', help='The path to the hf tokenizer.')
- parser.add_argument(
- '--tokenizer-w-special-tokens-save-dir',
- default=None,
- help='We have to add special tokens to the vocabulary of '
- 'the given tokenizer, and save the new tokenizer to this folder.')
- args = parser.parse_args()
- return args
-
-
-def main():
- args = parse_args()
- tokenizer = AutoTokenizer.from_pretrained(
- args.tokenizer_path, trust_remote_code=True, padding_side='right')
-
- ori_vocab_size = len(tokenizer)
- tokenizer_add_special_tokens(tokenizer)
- if len(tokenizer) != ori_vocab_size:
- save_new_tokenizer(tokenizer, args.tokenizer_w_special_tokens_save_dir)
-
- tokenize_and_save(tokenizer, args.processed_dir, args.tokenized_dir)
-
-
-if __name__ == '__main__':
- main()
diff --git a/code/xtuner/tools/train.py b/code/xtuner/tools/train.py
deleted file mode 100644
index 3d07dc903229fa1bfb6d64987364d5daeef9ca88..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/train.py
+++ /dev/null
@@ -1,553 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import argparse
-import json
-import logging
-import os
-import torch
-import os.path as osp
-from functools import partial
-from types import FunctionType
-
-import torch.distributed as dist
-
-from mmengine.config import Config, DictAction
-from mmengine.config.lazy import LazyObject
-from mmengine.logging import print_log
-from mmengine.registry import RUNNERS
-from mmengine.runner import Runner
-from mmengine.utils import digit_version
-from mmengine.hooks import Hook # ← added
-from mmengine.logging import print_log
-
-from peft import get_peft_model, prepare_model_for_kbit_training
-from transformers import TrainingArguments
-
-from xtuner.configs import cfgs_name_path
-from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.model.modules import dispatch_modules
-from xtuner.model.modules.dispatch import SUPPORT_FLASH2
-from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict
-from xtuner.registry import BUILDER, MAP_FUNC
-from xtuner.tools.utils import (auto_dtype_of_deepspeed_config,
- get_seed_from_checkpoint)
-
-torch.autograd.set_detect_anomaly(True)
-# torch.cuda.set_sync_debug_mode(1) # 可选,便于定位哪一步产生Inf/NaN
-
-def _unwrap(m):
- for a in ("module","model","_model","module_wrapped"):
- if hasattr(m, a) and hasattr(getattr(m,a), "named_parameters"):
- return getattr(m,a)
- return m
-
-# ---------------- 新增:训练时打印梯度&探针的 Hook ----------------
-class GradProbeHook(Hook):
- """每隔 n iter 打印 projector 梯度与 A/B/C 三个探针的 grad L2;并校验占位符被替换。"""
- priority = "LOW"
- def __init__(self, every_n_iters=20, only_rank0=True):
- self.n = every_n_iters
- self.only0 = only_rank0
-
- def _is0(self):
- try:
- return (not dist.is_available()) or (not dist.is_initialized()) or dist.get_rank()==0
- except Exception:
- return True
-
- def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
- if self.n<=0 or (batch_idx+1)%self.n!=0: return
- if self.only0 and not self._is0(): return
-
- m = _unwrap(runner.model)
- if not hasattr(m, "projector"):
- print_log("[GradProbe] no projector found", logger=runner.logger)
- return
-
- # 1) projector 参数梯度统计
- n_tot = n_with = n_none = 0
- total_l2 = 0.0; max_abs = 0.0
- for name, p in m.projector.named_parameters():
- n_tot += 1
- g = p.grad
- if g is None:
- n_none += 1
- continue
- n_with += 1
- gf = g.detach().float()
- total_l2 += float(torch.linalg.vector_norm(gf).item())
- if gf.numel()>0:
- max_abs = max(max_abs, float(gf.abs().max().item()))
-
- # 2) 三个探针的 grad
- def gnorm(x):
- if x is None or x.grad is None: return None
- return float(torch.linalg.vector_norm(x.grad.detach().float()).item())
-
- gA = gnorm(getattr(m, "_tap_A_proj_out", None)) # projector 输出
- gB = gnorm(getattr(m, "_tap_B_to_prepare", None)) # 送 prepare 前最终视觉张量
- gC = gnorm(getattr(m, "_tap_C_inputs_embeds", None)) # 传入 LLM 的 inputs_embeds(需要在 compute_loss 前设置)
-
- # 3) 插入检查结果(来自 forward 的缓存)
- repl = getattr(m, "_tap_replacement_ok", None)
-
- msg = (f"=== [GradProbe] iter={batch_idx+1}\n"
- f"projector params: with_grad={n_with}/{n_tot} none={n_none} "
- f"sumL2={total_l2:.3e} max|g|={max_abs:.3e}\n"
- f"A=proj_out_grad_L2: {gA}\n"
- f"B=to_prepare_grad_L2: {gB}\n"
- f"C=inputs_embeds_grad_L2: {gC}\n"
- f"replace_ok={repl}")
- print_log(msg, logger=runner.logger)
-
-
-# ------------- helper to print trainable parameters -------------
-def _resolve_torch_model(obj):
- """Try to unwrap containers (DDP/Deepspeed) to an nn.Module."""
- if obj is None:
- return None
- m = obj
- # unwrap common wrappers
- for attr in ("module", "model", "_model", "module_wrapped"):
- if hasattr(m, attr) and hasattr(getattr(m, attr), "named_parameters"):
- m = getattr(m, attr)
- break
- return m if hasattr(m, "named_parameters") else None
-
-def print_trainable_parameters(model_like, *, show_names=True, max_names=200, logger='current'):
- try:
- model = _resolve_torch_model(model_like) or model_like
- if not hasattr(model, "named_parameters"):
- print_log("[Trainable] Model is not built yet (no named_parameters).", logger=logger)
- return
- total = 0
- trainable = 0
- trainable_lines = []
- for n, p in model.named_parameters():
- num = p.numel()
- total += num
- if p.requires_grad:
- trainable += num
- if show_names:
- trainable_lines.append(f"{n} shape={tuple(p.shape)} dtype={p.dtype}")
- pct = (trainable / total * 100) if total else 0.0
- print_log(f"[Trainable] {trainable:,} / {total:,} params ({pct:.2f}%)", logger=logger)
- if show_names:
- if len(trainable_lines) > max_names:
- head = "\n".join(trainable_lines[:max_names])
- print_log(head, logger=logger)
- print_log(f"... ({len(trainable_lines) - max_names} more trainable tensors not shown)", logger=logger)
- else:
- print_log("\n".join(trainable_lines) if trainable_lines else "(no trainable tensors)", logger=logger)
- except Exception as e:
- print_log(f"[Trainable] Failed to enumerate parameters: {e}", logger=logger)
-# ---------------------------------------------------------------
-
-# -------- Hook to print after MMEngine actually builds the model --------
-class PrintTrainableParamsHook(Hook):
- """MMEngine hook: print trainable params once the runner/model are ready."""
- priority = "LOW"
-
- def before_train(self, runner):
-
- # 1. 从 runner 获取 optimizer
- opt = runner.optim_wrapper.optimizer
-
- names_in_opt = set()
- # 2. 遍历 optimizer 的所有参数组
- for pg in opt.param_groups:
- # 3. 在每个组中遍历所有参数
- for p in pg['params']:
- # MMEngine 会为参数附加 _param_name 属性,便于识别
- if hasattr(p, '_param_name'):
- names_in_opt.add(p._param_name)
-
- # 打印出所有在 optimizer 中的参数名
- print("Parameters in optimizer:", names_in_opt)
-
- # 您的代码甚至更进一步,检查了 projector 中有哪些可训练参数没有被加入 optimizer
- miss = [n for n,_ in _unwrap(runner.model).projector.named_parameters() if n not in names_in_opt]
- print_log(f"[check] projector params NOT in optimizer: {miss}", logger=runner.logger)
-
- print("*******************************************************Trainable parameters (MMEngine)")
-
- print_trainable_parameters(runner.model)
-# -----------------------------------------------------------------------
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description='Train LLM')
- parser.add_argument('config', help='config file name or path.')
- parser.add_argument('--work-dir', help='the dir to save logs and models')
- parser.add_argument(
- '--deepspeed',
- type=str,
- default=None,
- help='the path to the .json file for deepspeed')
- parser.add_argument(
- '--resume',
- type=str,
- default=None,
- help='specify checkpoint path to be resumed from.')
- parser.add_argument(
- '--seed', type=int, default=None, help='Random seed for the training')
- parser.add_argument(
- '--cfg-options',
- nargs='+',
- action=DictAction,
- help='override some settings in the used config, the key-value pair '
- 'in xxx=yyy format will be merged into config file. If the value to '
- 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
- 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
- 'Note that the quotation marks are necessary and that no white space '
- 'is allowed.')
- parser.add_argument(
- '--launcher',
- choices=['none', 'pytorch', 'slurm', 'mpi'],
- default='none',
- help='job launcher')
- parser.add_argument('--local_rank', '--local-rank', type=int, default=2)
- args = parser.parse_args()
-
- last_checkpoint_file = os.path.join(args.work_dir, 'last_checkpoint')
- # Ensure the directory exists
- os.makedirs(args.work_dir, exist_ok=True)
- if not os.path.exists(last_checkpoint_file):
- # Create the file if it doesn't exist
- open(last_checkpoint_file, 'w').close()
- print('-------------------------------------------------- Created empty last_checkpoint file ----------------------------------------------')
-
- with open(last_checkpoint_file, 'r', encoding='utf-8') as file:
- first_line = file.readline().strip()
-
- if first_line:
- print('--------------------------------------------------resume----------------------------------------------', first_line)
- args.resume = first_line
- else:
- print('-------------------------------------------------- NO resume----------------------------------------------')
- args.resume = None
-
- return args
-
-
-def register_function(cfg_dict):
- if isinstance(cfg_dict, dict):
- for key, value in dict.items(cfg_dict):
- if isinstance(value, FunctionType):
- value_str = str(value)
- if value_str not in MAP_FUNC:
- MAP_FUNC.register_module(module=value, name=value_str)
- cfg_dict[key] = value_str
- else:
- register_function(value)
- elif isinstance(cfg_dict, (list, tuple)):
- for value in cfg_dict:
- register_function(value)
-
-
-def check_cfg(cfg, args):
- if getattr(cfg, 'use_varlen_attn',
- False) and cfg.train_dataloader.batch_size > 1:
- raise NotImplementedError(
- f'If utilizing varlen attention, the batch size should be'
- f' set to 1, but got {cfg.train_dataloader.batch_size}')
-
- if getattr(cfg, 'use_varlen_attn', False):
- sequence_parallel = getattr(cfg, 'sequence_parallel', 1)
- max_length = getattr(cfg.train_dataloader.dataset, 'max_length', None)
- if max_length is not None:
- assert max_length % sequence_parallel == 0, \
- ('When using varlen attention, `max_length` should be evenly '
- 'divided by sequence parallel world size, but got '
- f'max_length = {max_length} and sequence_parallel = '
- f'{sequence_parallel}')
-
- if getattr(cfg, 'sequence_parallel_size', 1) > 1:
- assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use '
- 'sequence parallel.')
- attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
- None)
- assert (attn_implementation is None or
- attn_implementation == 'flash_attention_2'), \
- ('If you want to use sequence parallel, please set '
- 'attn_implementation to `flash_attention_2` or do not '
- f'set this attribute. Got `{attn_implementation}` .')
-
- if getattr(cfg, 'use_varlen_attn', False):
- assert SUPPORT_FLASH2, ('`flash_attn` is required if you set '
- '`use_varlen_attn` to True.')
- attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
- None)
- assert (attn_implementation is None or
- attn_implementation == 'flash_attention_2'), \
- ('If you want to set `use_varlen_attn` to True, please set'
- ' attn_implementation to `flash_attention_2` or do not '
- f'set this attribute. Got `{attn_implementation}` .')
-
- if args.deepspeed is None:
- assert getattr(cfg, 'sequence_parallel_size', 1) == 1, \
- ('Sequence parallel training without DeepSpeed lacks validation.'
- 'Please use DeepSpeed to optimize the training phase by '
- '`--deepspeed deepspeed_zero1 (deepspeed_zero2 or '
- 'deepspeed_zero3)`.')
-
-
-def main():
- print('*******************************************************Start Initialization')
-
- args = parse_args()
-
- # parse config
- if not osp.isfile(args.config):
- try:
- args.config = cfgs_name_path[args.config]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.config}')
-
- # load config
- cfg = Config.fromfile(args.config)
- # Force registry import so your constructor is registered
- from xtuner.engine.optimizer.muon_wrapper import MuonOptimWrapperConstructor
-
- if args.cfg_options is not None:
- cfg.merge_from_dict(args.cfg_options)
-
- # register FunctionType object in cfg to `MAP_FUNC` Registry and
- # change these FunctionType object to str
- register_function(cfg._cfg_dict)
-
- check_cfg(cfg, args)
-
-
- if cfg.get('framework', 'mmengine').lower() == 'huggingface':
- # set default training_args
- if cfg.get('training_args', None) is None:
- cfg.training_args = dict(type=TrainingArguments)
- if args.seed is not None:
- cfg.training_args.seed = args.seed
- # set work_dir
- if args.work_dir is not None:
- cfg.training_args.output_dir = args.work_dir
- elif cfg.training_args.get('output_dir', None) is None:
- cfg.training_args.output_dir = osp.join(
- './work_dirs',
- osp.splitext(osp.basename(args.config))[0])
- # enable deepspeed
- if args.deepspeed:
- if not osp.isfile(args.deepspeed):
- try:
- args.deepspeed = cfgs_name_path[args.deepspeed]
- except KeyError:
- raise FileNotFoundError(f'Cannot find {args.deepspeed}')
- cfg.training_args.deepspeed = args.deepspeed
- if cfg.training_args.get('deepspeed'):
- device_map = None
- else:
- # Data Parallel
- device_map = {
- '': int(os.environ.get('LOCAL_RANK', args.local_rank))
- }
- # build training_args
- training_args = BUILDER.build(cfg.training_args)
- # build model
- print('*******************************************************Build Model')
-
- with LoadWoInit():
- cfg.model.device_map = device_map
- traverse_dict(cfg.model)
- model = BUILDER.build(cfg.model)
- model.config.use_cache = False
- dispatch_modules(model)
- if cfg.get('lora', None):
- lora = BUILDER.build(cfg.lora)
- model = prepare_model_for_kbit_training(model)
- if lora.target_modules is None:
- modules = find_all_linear_names(model)
- lora.target_modules = modules
- model = get_peft_model(model, lora)
-
- # Print trainable params (HF branch)
- print("*******************************************************Trainable parameters (HF)")
- print_trainable_parameters(model)
-
- # build dataset
- print('*******************************************************Build Dataset')
- train_dataset = BUILDER.build(cfg.train_dataset)
- data_collator = partial(default_collate_fn, return_hf_format=True)
- # build trainer
- trainer = cfg.trainer(
- model=model,
- args=training_args,
- train_dataset=train_dataset,
- data_collator=data_collator)
- # training
- print('*******************************************************start training')
- trainer.train(resume_from_checkpoint=args.resume)
- trainer.save_state()
- trainer.save_model(output_dir=training_args.output_dir)
- else:
- if args.seed is not None and args.resume is None:
- # Use args.seed
- cfg.merge_from_dict(dict(randomness=dict(seed=args.seed)))
- print_log(
- f'Set the random seed to {args.seed}.',
- logger='current',
- level=logging.INFO)
- elif args.resume is not None: # False
- # Use resumed seed
- from mmengine.fileio import PetrelBackend, get_file_backend
-
- from xtuner.utils.fileio import patch_fileio
- backend = get_file_backend(args.resume)
- if isinstance(backend, PetrelBackend):
- with patch_fileio():
- resumed_seed = get_seed_from_checkpoint(args.resume)
- else:
- resumed_seed = get_seed_from_checkpoint(args.resume)
- cfg.merge_from_dict(dict(randomness=dict(seed=resumed_seed)))
- if args.seed is not None and args.seed != resumed_seed:
- print_log(
- (f'The value of random seed in resume checkpoint '
- f'"{args.resume}" is different from the value in '
- f'arguments. The resumed seed is {resumed_seed}, while '
- f'the input argument seed is {args.seed}. Using the '
- f'resumed seed {resumed_seed}.'),
- logger='current',
- level=logging.WARNING)
- else:
- print_log(
- f'Set the random seed to {resumed_seed}.',
- logger='current',
- level=logging.INFO)
-
- if 'LOCAL_RANK' not in os.environ:
- os.environ['LOCAL_RANK'] = str(args.local_rank)
- cfg.launcher = args.launcher
- # work_dir is determined in this priority:
- # CLI > segment in file > filename
- if args.work_dir is not None:
- cfg.work_dir = args.work_dir
- elif cfg.get('work_dir', None) is None:
- cfg.work_dir = osp.join('./work_dirs',
- osp.splitext(osp.basename(args.config))[0])
-
- if args.deepspeed:
- try:
- import deepspeed
- except ImportError:
- raise ImportError(
- 'deepspeed is not installed properly, please check.')
- if digit_version(deepspeed.__version__) < digit_version('0.12.3'):
- raise RuntimeError('Please upgrade your DeepSpeed version '
- 'by using the command pip install '
- '`deepspeed>=0.12.3`')
- optim_wrapper = cfg.optim_wrapper.type
- if optim_wrapper == 'DeepSpeedOptimWrapper':
- print_log(
- 'Deepspeed training is already enabled in your config.',
- logger='current',
- level=logging.WARNING)
- else:
- if not osp.isfile(args.deepspeed):
- try:
- args.deepspeed = cfgs_name_path[args.deepspeed]
- except KeyError:
- raise FileNotFoundError(
- f'Cannot find {args.deepspeed}')
- with open(args.deepspeed) as f:
- ds_cfg = json.load(f)
-
- ds_grad_accum = ds_cfg.get('gradient_accumulation_steps',
- 'auto')
- mm_grad_accum = cfg.optim_wrapper.get('accumulative_counts', 1)
- if ds_grad_accum != 'auto' and ds_grad_accum != mm_grad_accum:
- print_log(('Mismatch on gradient_accumulation_steps: '
- f'MMEngine {mm_grad_accum}, '
- f'Deepspeed {ds_grad_accum}. '
- f'Set to {mm_grad_accum}'),
- logger='current',
- level=logging.WARNING)
- grad_accum = mm_grad_accum
-
- ds_train_bs = ds_cfg.get('train_micro_batch_size_per_gpu',
- 'auto')
- mm_train_bs = cfg.train_dataloader.batch_size
- if ds_train_bs != 'auto' and ds_train_bs != mm_train_bs:
- print_log(
- ('Mismatch on train_micro_batch_size_per_gpu: '
- f'MMEngine {mm_train_bs}, Deepspeed {ds_train_bs}. '
- f'Set to {mm_train_bs}'),
- logger='current',
- level=logging.WARNING)
- train_bs = cfg.train_dataloader.batch_size
- print('train_bs', train_bs)
-
- ds_grad_clip = ds_cfg.get('gradient_clipping', 'auto')
- clip_grad = cfg.optim_wrapper.get('clip_grad', None)
- if clip_grad and clip_grad.get('max_norm', None) is not None:
- mm_max_norm = cfg.optim_wrapper.clip_grad.max_norm
- else:
- mm_max_norm = 1.0
- if ds_grad_clip != 'auto' and ds_grad_clip != mm_max_norm:
- print_log(
- ('Mismatch on gradient_clipping: '
- f'MMEngine {mm_max_norm}, Deepspeed {ds_grad_clip}. '
- f'Set to {mm_max_norm}'),
- logger='current',
- level=logging.WARNING)
- grad_clip = mm_max_norm
- ds_cfg = auto_dtype_of_deepspeed_config(ds_cfg)
- exclude_frozen_parameters = True if digit_version(
- deepspeed.__version__) >= digit_version('0.10.1') else None
- strategy = dict(
- type=LazyObject('xtuner.engine', 'DeepSpeedStrategy'),
- config=ds_cfg,
- gradient_accumulation_steps=grad_accum,
- train_micro_batch_size_per_gpu=train_bs,
- gradient_clipping=grad_clip,
- exclude_frozen_parameters=exclude_frozen_parameters,
- sequence_parallel_size=getattr(cfg,
- 'sequence_parallel_size',
- 1))
- cfg.__setitem__('strategy', strategy)
- # 取原始 optim_wrapper,便于拿出可选项
- orig_ow = cfg.optim_wrapper if hasattr(cfg, 'optim_wrapper') else {}
-
- optim_wrapper = dict(
- type='DeepSpeedOptimWrapper',
- optimizer=orig_ow.get('optimizer'),
- # 关键:把 paramwise_cfg 传过来,custom_keys 才会生效
- paramwise_cfg=orig_ow.get('paramwise_cfg', None),
- )
-
- # 如果你在 cfg 里定义了自定义的 OptimWrapperConstructor,也一并传递
- if hasattr(cfg, 'optim_wrapper_constructor'):
- cfg.optim_wrapper_constructor = cfg.optim_wrapper_constructor
- cfg.__setitem__('optim_wrapper', optim_wrapper)
-
- cfg.runner_type = 'FlexibleRunner'
-
- # resume is determined in this priority: resume from > auto_resume
- if args.resume is not None:
- cfg.resume = True
- cfg.load_from = args.resume
-
- # build the runner from config
- if 'runner_type' not in cfg: # False
- # build the default runner
- runner = Runner.from_cfg(cfg)
- else:
- # build customized runner from the registry
- # if 'runner_type' is set in the cfg
- runner = RUNNERS.build(cfg)
-
- # Register hook to print once the model is actually built
- runner.register_hook(PrintTrainableParamsHook(), priority='LOW')
- # 注册 projector 梯度检查
- # runner.register_hook(GradProbeHook(every_n_iters=10, only_rank0=True), priority='LOW')
-
- # start training
- runner.train()
-
-
-if __name__ == '__main__':
- main()
\ No newline at end of file
diff --git a/code/xtuner/tools/utils.py b/code/xtuner/tools/utils.py
deleted file mode 100644
index df130c5f91a83ccd780fba247bbbc9b31aaa2c31..0000000000000000000000000000000000000000
--- a/code/xtuner/tools/utils.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import re
-import warnings
-
-import torch
-from transformers import PreTrainedTokenizerFast, StoppingCriteriaList
-from transformers.generation.streamers import BaseStreamer
-
-from xtuner.utils import StopWordStoppingCriteria
-
-
-def get_base_model(model):
- if hasattr(model, 'llm'):
- model = model.llm
- if 'PeftModel' in model.__class__.__name__:
- model = model.base_model.model
- return model
-
-
-def get_streamer(model):
- # TODO: deprecation, v0.3.0
- warnings.warn(
- ('`get_streamer` is deprecated and will be removed in v0.3.0, '
- "use `transformers`'s `TextStreamer` instead."), DeprecationWarning)
- if model.__class__.__name__ == 'InferenceEngine':
- model = model.module
- base_model = get_base_model(model)
- base_model_name = base_model.__class__.__name__.lower()
- is_internlm = 'internlm' in base_model_name
- is_qwen = 'qwen' in base_model_name
- is_baichuan = 'baichuan' in base_model_name
- is_chatglm = 'chatglm' in base_model_name
- no_space = is_internlm or is_qwen or is_baichuan or is_chatglm
- if no_space:
- return NoSpaceStreamer
- else:
- return DecodeOutputStreamer
-
-
-class DecodeOutputStreamer(BaseStreamer):
- """Default streamer for HuggingFace models."""
-
- def __init__(self, tokenizer, skip_prompt=True) -> None:
- super().__init__()
- # TODO: deprecation, v0.3.0
- warnings.warn(
- '`DecodeOutputStreamer` is deprecated and will be '
- 'removed in v0.3.0.', DeprecationWarning)
- self.tokenizer = tokenizer
- self.skip_prompt = skip_prompt
- self.gen_len = 0
- if isinstance(tokenizer, PreTrainedTokenizerFast):
- self.decode = self._decode_with_raw_id
- self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
- else:
- self.decode = self._decode_fallback
-
- def _decode_with_raw_id(self, value):
- """Convert token ids to tokens and decode."""
-
- tok = self.tokenizer._convert_id_to_token(value)
- if tok.startswith('▁'): # sentencepiece
- space = ' '
- tok = tok[1:]
- else:
- space = ''
- if res := self.hex_regex.match(tok):
- tok = chr(int(res.group(1), 16))
- if tok == '':
- tok = '\n'
- return space + tok
-
- def _decode_fallback(self, value):
- """Fallback decoder for non-fast tokenizer."""
-
- tok = self.tokenizer.decode(
- value,
- skip_special_tokens=False,
- clean_up_tokenization_spaces=False)
- return tok + ' '
-
- def put(self, value):
- """Callback function to decode token and output to stdout."""
-
- if self.gen_len == 0 and self.skip_prompt:
- pass
- else:
- tok = self.decode(value[0])
- print(tok, end='', flush=True)
-
- self.gen_len += 1
-
- def end(self):
- """Callback function to finish generation."""
-
- print('\n')
-
-
-class NoSpaceStreamer(DecodeOutputStreamer):
-
- def __init__(self, tokenizer, skip_prompt=True) -> None:
- BaseStreamer().__init__()
- # TODO: deprecation, v0.3.0
- warnings.warn(
- '`NoSpaceStreamer` is deprecated and will be '
- 'removed in v0.3.0.', DeprecationWarning)
- self.tokenizer = tokenizer
- self.skip_prompt = skip_prompt
- self.gen_len = 0
- self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
-
- def decode(self, value):
- tok = self.tokenizer.decode(value)
- if res := self.hex_regex.match(tok):
- tok = chr(int(res.group(1), 16))
- if tok == '' or tok == '\r':
- tok = '\n'
-
- return tok
-
-
-def get_stop_criteria(
- tokenizer,
- stop_words=[],
-):
- stop_criteria = StoppingCriteriaList()
- for word in stop_words:
- stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
- return stop_criteria
-
-
-def auto_dtype_of_deepspeed_config(ds_config):
- if ds_config.get('fp16') and not ds_config.get('bf16'):
- if ds_config.get('fp16').get('enabled') == 'auto':
- ds_config['fp16']['enabled'] = torch.cuda.is_available()
- elif not ds_config.get('fp16') and ds_config.get('bf16'):
- if ds_config.get('bf16').get('enabled') == 'auto':
- ds_config['bf16']['enabled'] = torch.cuda.is_bf16_supported()
- elif ds_config.get('fp16') and ds_config.get('bf16'):
- if ds_config.get('fp16').get('enabled') == 'auto':
- ds_config['fp16']['enabled'] = torch.cuda.is_available()
- if ds_config.get('bf16').get('enabled') == 'auto':
- ds_config['bf16']['enabled'] = torch.cuda.is_bf16_supported()
- if (ds_config['fp16']['enabled'] is True
- and ds_config['bf16']['enabled'] is True):
- ds_config['fp16']['enabled'] = False
- ds_config['bf16']['enabled'] = True
- return ds_config
-
-
-def is_cn_string(s):
- if re.search('[\u4e00-\u9fff]', s):
- return True
- return False
-
-
-def get_seed_from_checkpoint(pth_model):
- if osp.isfile(pth_model):
- checkpoint = torch.load(pth_model, map_location='cpu')
- elif osp.isdir(pth_model):
- try:
- from deepspeed.utils.zero_to_fp32 import get_model_state_files
- except ImportError:
- raise ImportError(
- 'The provided PTH model appears to be a DeepSpeed checkpoint. '
- 'However, DeepSpeed library is not detected in current '
- 'environment. This suggests that DeepSpeed may not be '
- 'installed or is incorrectly configured. Please verify your '
- 'setup.')
- filename = get_model_state_files(pth_model)[0]
- checkpoint = torch.load(filename, map_location='cpu',weights_only=False)
- else:
- raise FileNotFoundError(f'Cannot find {pth_model}')
- return checkpoint['meta']['seed']
diff --git a/code/xtuner/utils/__init__.py b/code/xtuner/utils/__init__.py
deleted file mode 100644
index 6663b32253528a8d02b61e1dec07326116ba6130..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .constants import (DEFAULT_IMAGE_TOKEN, DEFAULT_PAD_TOKEN_INDEX,
- IGNORE_INDEX, IMAGE_TOKEN_INDEX)
-from .handle_moe_load_and_save import (SUPPORT_MODELS, get_origin_state_dict,
- load_state_dict_into_model)
-from .stop_criteria import StopWordStoppingCriteria
-from .templates import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
-
-__all__ = [
- 'IGNORE_INDEX', 'DEFAULT_PAD_TOKEN_INDEX', 'PROMPT_TEMPLATE',
- 'DEFAULT_IMAGE_TOKEN', 'SYSTEM_TEMPLATE', 'StopWordStoppingCriteria',
- 'IMAGE_TOKEN_INDEX', 'load_state_dict_into_model', 'get_origin_state_dict',
- 'SUPPORT_MODELS'
-]
diff --git a/code/xtuner/utils/__pycache__/__init__.cpython-311.pyc b/code/xtuner/utils/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 247e7265c5b611ec280e016ddb64fb78f5e332e8..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/__init__.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/__pycache__/constants.cpython-311.pyc b/code/xtuner/utils/__pycache__/constants.cpython-311.pyc
deleted file mode 100644
index db0a7b3f07585c33a512d271e62164a069cb9927..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/constants.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/__pycache__/fileio.cpython-311.pyc b/code/xtuner/utils/__pycache__/fileio.cpython-311.pyc
deleted file mode 100644
index 44e98fb42a4ccf073ca05ed2b9d3ea86a856bca4..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/fileio.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/__pycache__/handle_moe_load_and_save.cpython-311.pyc b/code/xtuner/utils/__pycache__/handle_moe_load_and_save.cpython-311.pyc
deleted file mode 100644
index 4b3800103eecb45e321f8b32dbeebfe215240458..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/handle_moe_load_and_save.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/__pycache__/stop_criteria.cpython-311.pyc b/code/xtuner/utils/__pycache__/stop_criteria.cpython-311.pyc
deleted file mode 100644
index be20a908d6e13b3143d240e77cc7bbea632bf962..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/stop_criteria.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/__pycache__/templates.cpython-311.pyc b/code/xtuner/utils/__pycache__/templates.cpython-311.pyc
deleted file mode 100644
index daea0fa3344a9809c4f2ec974852e0b48ce8b55f..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/templates.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/__pycache__/zero_to_any_dtype.cpython-311.pyc b/code/xtuner/utils/__pycache__/zero_to_any_dtype.cpython-311.pyc
deleted file mode 100644
index e3a8810d7ab21f777dc322fab2141168e8eff8c2..0000000000000000000000000000000000000000
Binary files a/code/xtuner/utils/__pycache__/zero_to_any_dtype.cpython-311.pyc and /dev/null differ
diff --git a/code/xtuner/utils/constants.py b/code/xtuner/utils/constants.py
deleted file mode 100644
index 2862c8ab50bb3f811795f5b8aea0d991505d6a41..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/constants.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-IGNORE_INDEX = -100
-DEFAULT_PAD_TOKEN_INDEX = 0
-IMAGE_TOKEN_INDEX = -200
-DEFAULT_IMAGE_TOKEN = ''
diff --git a/code/xtuner/utils/device.py b/code/xtuner/utils/device.py
deleted file mode 100644
index 8b1c87fcb8867cb5dec980099e469986c3a8707c..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/device.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# This code is inspired by the torchtune.
-# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py
-
-import logging
-from typing import Optional
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-def is_torch_npu_available() -> bool:
- """Check the availability of NPU."""
- try:
- import torch_npu # noqa: F401
-
- return torch.npu.is_available()
- except ImportError:
- return False
-
-
-is_cuda_available = torch.cuda.is_available()
-is_npu_available = is_torch_npu_available()
-
-
-def get_device_name() -> str:
- """Function that gets the torch.device based on the current machine.
-
- This currently only supports CPU, CUDA, NPU.
-
- Returns:
- device
- """
- if is_cuda_available:
- device = "cuda"
- elif is_npu_available:
- device = "npu"
- else:
- device = "cpu"
- return device
-
-
-def get_device(device_name: Optional[str] = None) -> torch.device:
- """Function that takes an optional device string, verifies it's correct and
- available given the machine and distributed settings, and returns a
- :func:`~torch.device`. If device string is not provided, this function will
- infer the device based on the environment.
-
- If CUDA-like is available and being used, this function also sets the CUDA-like device.
-
- Args:
- device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu".
-
- Example:
- >>> device = get_device("cuda")
- >>> device
- device(type='cuda', index=0)
-
- Returns:
- torch.device: Device
- """
- if device_name is None:
- device_name = get_device_name()
- device = torch.device(device_name)
- return device
-
-
-def get_torch_device() -> any:
- """Return the corresponding torch attribute based on the device type
- string.
-
- Returns:
- module: The corresponding torch device namespace, or torch.cuda if not found.
- """
- device_name = get_device_name()
- try:
- return getattr(torch, device_name)
- except AttributeError:
- logger.warning(
- f"Device namespace '{device_name}' not found in torch, try to load torch.cuda."
- )
- return torch.cuda
diff --git a/code/xtuner/utils/dynamic_llava_dispatch.py b/code/xtuner/utils/dynamic_llava_dispatch.py
deleted file mode 100644
index 9f30aaff73b12191c7531f10e64da1a2860ba73f..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/dynamic_llava_dispatch.py
+++ /dev/null
@@ -1,1255 +0,0 @@
-"""
-Dynamic-LLaVA style pruning + custom losses integrated into XTuner LLaVA
-without breaking your LongNet flow. This is a *runtime* patch you can import
-and call after you construct your XTuner model + tokenizer.
-
-Phase 1 (safe-by-default): mask-only pruning
-- Learns token-keep policies for image / instruction / answer regions.
-- Applies them by augmenting the attention_mask (keys only) so dropped tokens
- are never attended to. Sequence length is unchanged → no KV/cache surgery.
-- Labels for dropped answer tokens are set to IGNORE_INDEX so they don't
- contribute to LM loss.
-
-Phase 2 (optional): gather-mode pruning (draft hooks provided)
-- Switch `mode="gather"` to actually remove tokens before the first
- transformer layer (for image) and from the attention keys/values for text.
-- You will likely iterate here to match dynamic-llava's varlen behavior.
-
-Usage
------
-from dynamic_xtuner_patch import integrate_dynamic_llava, DynamicCfg
-
-model = ... # your XTuner LLaVAModel instance
- tokenizer = ...
- dyn_cfg = DynamicCfg()
- integrate_dynamic_llava(model, tokenizer, dyn_cfg)
-
-Now train as usual. The model will compute the extra policy losses and merge
-its dynamic masks into attention.
-
----
-
-## NEW: Layer‑wise policies (closer to dynamic‑llava)
-
-The initial patch predicted one mask per forward pass. This section adds *per‑layer*
-policies that re‑evaluate keep/drop at each decoder layer (without changing
-sequence length). It stays FA2‑compatible by injecting an **additive mask per
-layer** only for that layer’s attention call.
-
-### How it works
-- We wrap **Qwen2DecoderLayer.forward** at runtime.
-- At each layer, we take the current `hidden_states` → run the shared MLP
- predictors over the region spans (image / last_instruct / answer) → produce
- a layer‑specific additive mask and merge it with the layer’s `attention_mask`.
-- We accumulate policy losses per layer into a per‑forward aggregator that the
- top‑level `compute_loss` returns (no KV surgery needed).
-
-### What’s added
-- `enable_layerwise_policies(model, cfg, every_n_layers=1, target_layers=None)`
- – turns on layer‑wise masking; by default all layers are wrapped.
-- A small aggregator `model._dyn_ctx` to share per‑batch metadata and collect
- loss terms across layers.
-- Updated `_build_dynamic_inputs` to store `image_token_mask` and inferred
- spans into `self._dyn_ctx` so layers can access them.
-
-
-
-### Wire it up in the hook
-At the end of `before_train` in your `DynamicLlavaPatchHook`, call:
-
-```python
-from xtuner.utils.dynamic_llava_dispatch import enable_layerwise_policies
-enable_layerwise_policies(runner.model, cfg, every_n_layers=1)
-```
-
-### Loss aggregation change
-Update `_compute_loss_dynamic` to sum any per-layer losses collected in
-`self._dyn_ctx['loss_terms']`:
-
-```python
-outputs = self.llm(**data_dyn)
-base_loss = outputs.loss
-
-# collect per-layer extras
-for v in getattr(self, '_dyn_ctx', {}).get('loss_terms', []):
- base_loss = base_loss + v
-# clear for next step
-if hasattr(self, '_dyn_ctx'):
- self._dyn_ctx['loss_terms'] = []
-
-return {'loss': base_loss, **{k:v for k,v in extra_losses.items()}}
-```
-
-This gives you dynamic-llava‑style **layer‑wise re‑scoring** while keeping
-sequence length fixed (mask‑only). It’s FA2‑friendly and generation‑safe. When
-you’re ready, we can add gather‑mode for the image span pre‑LLM as a next step.
-```
-Notes
------
-- This expects your batches to use the common SFT convention where answer
- tokens have labels != IGNORE_INDEX and instruction/system tokens have labels
- == IGNORE_INDEX. We'll detect spans from labels.
-- Vision pruning is applied on the *projected* image tokens the moment before
- they are inserted into the sequence (keeps LongNet intact).
-- You *must* enable LoRA when using 4-bit quantization; see the docstring in
- `enable_lora_if_needed` below (we call it during integration if possible).
-"""
-
-from __future__ import annotations
-
-import math
-import types
-from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from mmengine.hooks import Hook
-from mmengine.registry import HOOKS
-
-IGNORE_INDEX = -100
-
-# =======================
-# Utilities (STE ops)
-# =======================
-
-def _gumbel_noise_like(x: torch.Tensor) -> torch.Tensor:
- # Uniform(0, 1) -> Gumbel(0, 1)
- u = torch.rand_like(x)
- return -torch.log(-torch.log(u + 1e-9) + 1e-9)
-
-
-def ste_gumbel_sigmoid(logits: torch.Tensor, tau: float) -> torch.Tensor:
- """Straight-through hard Bernoulli via Gumbel-Sigmoid.
- Returns {0,1} mask with gradients wrt logits.
- """
- y = (logits + _gumbel_noise_like(logits)) / max(tau, 1e-5)
- y = torch.sigmoid(y)
- y_hard = (y > 0.5).float()
- return y_hard + (y - y.detach())
-
-
-def ste_topk(scores: torch.Tensor, k: int) -> torch.Tensor:
- """Straight-through hard top-k mask across last dim."""
- if k <= 0:
- return torch.zeros_like(scores)
- k = min(k, scores.shape[-1])
- topk = torch.topk(scores, k=k, dim=-1).indices
- hard = torch.zeros_like(scores)
- hard.scatter_(-1, topk, 1.0)
- # STE trick
- soft = scores / (scores.abs().sum(-1, keepdim=True) + 1e-6)
- return hard + (soft - soft.detach())
-
-
-# =======================
-# Config
-# =======================
-
-@dataclass
-class DynamicCfg:
- # General
- mode: str = "mask" # "mask" | "gather"
- use_text_policy: bool = True
- use_vision_policy: bool = True
-
- # Vision policy
- vision_hidden: int = 1024
- vision_tau: float = 1.0
- target_keep_ratio_image: float = 0.5 # 50% of image tokens kept (avg)
-
- # Text policy
- text_hidden: int = 1024
- text_tau: float = 1.0
- target_keep_ratio_instruct: float = 0.5
- target_keep_ratio_answer: float = 0.8
-
- # Loss weights
- w_img_kl: float = 0.2
- w_instr_kl: float = 0.2
- w_ans_kl: float = 0.2
- w_entropy: float = 0.01
-
- # Misc
- min_tokens_per_region: int = 8 # never drop below this in gather-mode
- debug_log_spans: bool = False
-
-
-# =======================
-# Policy heads
-# =======================
-
-class MLPPolicy(nn.Module):
- def __init__(self, in_features: int, hidden: int):
- super().__init__()
- self.net = nn.Sequential(
- nn.LayerNorm(in_features),
- nn.Linear(in_features, hidden),
- nn.GELU(),
- nn.Linear(hidden, 1),
- )
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- # x: (B, T, C) -> logits: (B, T)
- return self.net(x).squeeze(-1)
-
-
-# =======================
-# Span helpers
-# =======================
-
-@torch.no_grad()
-def _infer_spans_from_labels(labels: torch.Tensor, image_token_mask: torch.Tensor) -> Dict[str, Tuple[int,int]]:
- """Infer spans using standard SFT convention.
- - `labels` shape (B, S) with IGNORE_INDEX for non-answer tokens.
- - `image_token_mask` shape (B, S) True for image tokens.
- Returns span dict per batch index (we handle B==1 here for simplicity of integration).
- """
- assert labels.dim() == 2 and labels.size(0) == 1, "This helper currently supports batch_size==1 for span inference."
- b = 0
- S = labels.size(1)
- # Image span = all contiguous True runs in image_token_mask; we take the full extent
- img_pos = torch.nonzero(image_token_mask[b], as_tuple=False).flatten()
- image_span = (int(img_pos.min().item()), int(img_pos.max().item()) + 1) if img_pos.numel() > 0 else (0, 0)
-
- # Answer span = last contiguous run where labels != IGNORE_INDEX
- ans_mask = (labels[b] != IGNORE_INDEX).bool()
- if ans_mask.any():
- idx = torch.nonzero(ans_mask, as_tuple=False).flatten()
- start, end = int(idx.min().item()), int(idx.max().item()) + 1
- answer_span = (start, end)
- else:
- answer_span = (S, S)
-
- # Instruction span: take the last run *before* answer where labels==IGNORE_INDEX and not image
- instr_mask = (~ans_mask) & (~image_token_mask[b])
- if answer_span[0] > 0:
- instr_mask[:answer_span[0]] = instr_mask[:answer_span[0]]
- instr_idx = torch.nonzero(instr_mask, as_tuple=False).flatten()
- if instr_idx.numel():
- instr_start, instr_end = int(instr_idx.min().item()), int(instr_idx.max().item()) + 1
- # constrain to before answer start
- instr_end = min(instr_end, answer_span[0])
- instruction_span = (instr_start, instr_end)
- else:
- instruction_span = (0, 0)
- else:
- instruction_span = (0, 0)
-
- return {"image": image_span, "last_instruct": instruction_span, "answer": answer_span}
-
-
-# =======================
-# Core integration
-# =======================
-
-def enable_lora_if_needed(model) -> None:
- """If model is 4-bit quantized, ensure LoRA is actually enabled.
- The user's class had self.use_llm_lora=None; we check and warn.
- """
- has_4bit = getattr(getattr(model, "llm", None), "is_loaded_in_4bit", False)
- if has_4bit:
- use_lora = getattr(model, "use_llm_lora", False)
- if not use_lora:
- print("[dynamic-xtuner] WARNING: LLM is 4-bit but LoRA is not enabled. Training may fail.")
-
-
-def integrate_dynamic_llava(model, tokenizer, cfg: Optional[DynamicCfg] = None):
- """Attach policy heads, losses, and forward hooks to an existing XTuner LLaVAModel instance.
- This is a *runtime patch*: no class source edits required.
- """
- if cfg is None:
- cfg = DynamicCfg()
-
- enable_lora_if_needed(model)
-
- # Infer sizes
- llm_hidden = int(model.llm.config.hidden_size)
-
- # Attach policy heads
- if cfg.use_vision_policy:
- model.vision_policy = MLPPolicy(llm_hidden, cfg.vision_hidden).to(model.llm.dtype).to(next(model.parameters()).device)
- else:
- model.vision_policy = None
-
- if cfg.use_text_policy:
- model.text_policy = MLPPolicy(llm_hidden, cfg.text_hidden).to(model.llm.dtype).to(next(model.parameters()).device)
- else:
- model.text_policy = None
-
- model.dynamic_cfg = cfg
-
- # Monkey-patch the compute_loss to inject policies & losses
- model._orig_compute_loss = model.compute_loss
- model.compute_loss = types.MethodType(_compute_loss_dynamic, model)
-
- # Also allow inference path to accept masks (no loss)
- model._orig_forward_tensor = getattr(model, "_forward", None)
- if model._orig_forward_tensor is not None:
- model._forward = types.MethodType(_forward_dynamic, model)
-
- print("[dynamic-xtuner] Integrated dynamic-llava (mask-mode by default).")
-
-
-# =======================
-# Patched methods
-# =======================
-
-def _build_dynamic_inputs(self, data: Dict) -> Tuple[Dict, Dict]:
- """Build inputs + region spans + policy masks, then merge masks into attention and labels.
- Returns (new_data, aux) where aux contains loss terms.
- Assumes `data` already contains text tensors prepared by dataset mapping
- and `data['pixel_values']` are *projected* image tokens in LLM hidden size.
- """
- cfg: DynamicCfg = self.dynamic_cfg
- device = data['input_ids'].device
-
- # 1) Work from tokenizer-produced text tensors
- input_ids: torch.Tensor = data['input_ids'] # (B, S)
- attention_mask: torch.Tensor = data['attention_mask'] # (B, S)
- labels: torch.Tensor = data['labels'] # (B, S)
-
- B, S = input_ids.shape
- assert B == 1, "This reference patch currently supports batch_size == 1 for dynamic spans."
-
- # 2) Turn input_ids into embeddings (so we can run text policy pre-LLM)
- tok_embed = self.llm.get_input_embeddings()
- text_embeds = tok_embed(input_ids) # (B, S, H)
-
- # 3) We expect the data dict is already extended by XTuner's prepare_* so
- # `data['pixel_values']` are *already in* the sequence via special handling.
- # To compute an image mask, detect image token positions from labels/ids.
- # Heuristic: XTuner sets labels for image tokens to IGNORE_INDEX and
- # uses special expansion. We mark as image when input_ids==pad and
- # attention_mask==1 but token embeddings equal to zero is unreliable.
- # Safer: XTuner stores an `image_bound` in mapping, but if absent we
- # infer via token id placeholder. We'll fall back to: any tokens whose
- # corresponding `labels==IGNORE_INDEX` *and* hidden state equals the
- # first row of `data['pixel_values']` is not directly accessible here.
- # Instead: we ask caller to supply `data['image_token_mask']` if available.
-
- image_token_mask = data.get('image_token_mask') # (B, S) bool or None
- if image_token_mask is None:
- # Fallback heuristic: treat consecutive IGNORE_INDEX region nearest to the beginning
- # where embeddings are unlikely to be pure text. We'll approximate by using
- # the first long IGNORE_INDEX block.
- approx = torch.zeros_like(attention_mask, dtype=torch.bool)
- # mark first run of >= 8 IGNORE_INDEX tokens as image
- run = (labels == IGNORE_INDEX).squeeze(0)
- if run.any():
- # find longest contiguous run
- idx = torch.where(run)[0]
- if idx.numel():
- diffs = torch.diff(idx, prepend=idx[:1])
- # crude: take first long run of 8+
- starts = idx[torch.where(diffs != 1)[0]] if (diffs != 1).any() else idx[:1]
- start = int(starts[0].item())
- end = start
- for j in range(start, S):
- if run[j].item():
- end = j + 1
- else:
- break
- if end - start >= 8:
- approx[0, start:end] = True
- image_token_mask = approx
-
- # 4) Infer spans from labels and image mask (B==1)
- spans = _infer_spans_from_labels(labels, image_token_mask)
- if cfg.debug_log_spans:
- print(f"[dynamic-xtuner] spans: {spans}")
-
- # 5) Build policy masks (vision, text)
- extra_mask = torch.zeros((B, S), dtype=text_embeds.dtype, device=device) # additive mask to keys
- losses: Dict[str, torch.Tensor] = {}
-
- # Vision
- if cfg.use_vision_policy and image_token_mask.any():
- img_s, img_e = spans["image"]
- H = text_embeds.shape[-1]
- # We can't directly access image embeddings here (they were inserted by the collator into inputs_embeds inside HF).
- # But in XTuner, `data['inputs_embeds']` is not precomputed at this point; instead HF will embed ids.
- # So we *approximate* by assuming image tokens are *already* represented by zeros in tok embedding,
- # which is not true. To keep this reference practical, we use the *projected* tokens we computed earlier
- # in LLaVAModel.forward just before calling this function and stash them into `data['__img_embeds__']`.
- img_embeds: Optional[torch.Tensor] = data.get("__img_embeds__") # (B, Nimg, H)
- if img_embeds is not None and img_embeds.numel() > 0:
- # Align lengths
- n_img = img_e - img_s
- img_embeds = img_embeds[:, :n_img, :]
- vision_logits = self.vision_policy(img_embeds.to(self.llm.dtype)) # (B, Nimg)
- if self.training:
- keep = ste_gumbel_sigmoid(vision_logits, tau=cfg.vision_tau)
- else:
- keep = (torch.sigmoid(vision_logits) > 0.5).float()
- # Targets: encourage avg keep ≈ target_keep_ratio_image
- avg_keep = keep.mean()
- target = torch.tensor(cfg.target_keep_ratio_image, device=device, dtype=avg_keep.dtype)
- losses['loss_img_kl'] = (avg_keep - target).pow(2) * cfg.w_img_kl
- # Entropy bonus to avoid collapse
- probs = torch.sigmoid(vision_logits)
- entropy = - (probs * (probs.clamp_min(1e-6).log()) + (1-probs) * ((1-probs).clamp_min(1e-6).log()))
- losses['loss_img_entropy'] = (entropy.mean() * cfg.w_entropy)
- # Build additive mask for keys: dropped -> -inf
- drop = (keep < 0.5).squeeze(0).bool()
- extra_mask[0, img_s:img_e][drop] = float('-inf')
- # In gather-mode we would actually remove tokens here; left as TODO.
- else:
- # No image embeds available; skip vision policy safely
- pass
-
- # Text (instruction + answer) – computed on input token embeddings
- if cfg.use_text_policy:
- for name, (w_kl, target_ratio) in {
- 'last_instruct': (cfg.w_instr_kl, cfg.target_keep_ratio_instruct),
- 'answer': (cfg.w_ans_kl, cfg.target_keep_ratio_answer),
- }.items():
- s, e = spans[name]
- if e > s:
- feats = text_embeds[:, s:e, :]
- logits = self.text_policy(feats.to(self.llm.dtype)) # (B, Tspan)
- if self.training:
- keep = ste_gumbel_sigmoid(logits, tau=cfg.text_tau)
- else:
- keep = (torch.sigmoid(logits) > 0.5).float()
- avg_keep = keep.mean()
- target = torch.tensor(target_ratio, device=device, dtype=avg_keep.dtype)
- losses[f'loss_{name}_kl'] = (avg_keep - target).pow(2) * w_kl
- probs = torch.sigmoid(logits)
- entropy = - (probs * (probs.clamp_min(1e-6).log()) + (1-probs) * ((1-probs).clamp_min(1e-6).log()))
- losses[f'loss_{name}_entropy'] = entropy.mean() * cfg.w_entropy
- # Keys mask
- drop = (keep < 0.5).squeeze(0).bool()
- extra_mask[0, s:e][drop] = float('-inf')
- # If dropping answer tokens, also ignore their labels
- if name == 'answer':
- labels = labels.clone()
- labels[0, s:e][drop] = IGNORE_INDEX
-
- # Merge constructed masks into data['attention_mask'] by additive convention used in HF
- # Transformers 4.40+ allows a float mask where 0.0 keeps, -inf masks.
- # Convert original attention_mask (0/1) -> (0.0 / -inf) additive form
- if attention_mask.dtype != extra_mask.dtype:
- attention_mask = attention_mask.to(extra_mask.dtype)
- additive = (1.0 - attention_mask) * float('-inf') # 1->0.0 keep, 0->-inf mask
- additive = additive + extra_mask # apply dynamic key masking
-
- new_data = dict(data)
- new_data['attention_mask'] = additive # HF attention paths will broadcast properly
- new_data['labels'] = labels
-
- return new_data, losses
-
-
-def _compute_loss_dynamic(self, data, data_samples=None):
- """Replacement for LLaVAModel.compute_loss that injects dynamic policies.
- Keeps LongNet+projector exactly as in the user's implementation.
- """
- # 1) First, run the usual multimodal preparation path to build `data` with pixel embeds
- # We hook right after LongNet+projector in the user's forward().
- # The user's forward() currently calls prepare_inputs_labels_for_multimodal *after*
- # producing `pixel_values`. We'll follow that ordering.
-
- # Ensure we're on the correct device (handled by caller's forward)
-
- # Stash projected image tokens (if any) for vision policy
- if 'pixel_values' in data and isinstance(data['pixel_values'], torch.Tensor):
- new_img = data['pixel_values']
- data['__img_embeds__'] = new_img.detach() if not self.training else new_img
-
- # Build dynamic masks and merge into inputs
- data_dyn, extra_losses = _build_dynamic_inputs(self, data)
-
- # 2) Forward through LLM with dynamic masks
- outputs = self.llm(**data_dyn)
- base_loss = outputs.loss
-
- # 3) Add policy losses
- total = base_loss
- for k, v in extra_losses.items():
- total = total + v
-
- # 4) Return standard dict
- return {'loss': total, **{k:v for k,v in extra_losses.items()}}
-
-
-def _forward_dynamic(self, data, data_samples=None):
- # In predict/tensor modes, we still honor dynamic masks if present but skip losses
- if 'pixel_values' in data and isinstance(data['pixel_values'], torch.Tensor):
- data['__img_embeds__'] = data['pixel_values']
- data_dyn, _ = _build_dynamic_inputs(self, data)
- return self.llm(**data_dyn)
-
-
-# =======================
-# (Optional) attention patch hook – placeholder
-# =======================
-
-def patch_qwen2_attention_for_dynamic(model) -> None:
- """Example hook point if you later want to pass `extra_mask` separately
- into attention forward. Not required for the mask-additive approach above.
- """
- import transformers
- from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2
-
- def wrap_forward(forward_fn):
- def new_forward(self_attn, *args, **kwargs):
- # Accept an optional `dynamic_additive_mask` and merge it inside
- extra = kwargs.pop('dynamic_additive_mask', None)
- out = forward_fn(self_attn, *args, **kwargs)
- return out
- return new_forward
-
- for mod in model.llm.modules():
- if isinstance(mod, transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2):
- mod.forward = types.MethodType(wrap_forward(mod.forward), mod)
-
- print("[dynamic-xtuner] Patched Qwen2FlashAttention2.forward (noop placeholder).")
-
-
-
-
-
-from __future__ import annotations
-
-import math
-import types
-from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# =======================
-# Constants
-# =======================
-IGNORE_INDEX = -100
-try:
- # XTuner exposes this; fall back gracefully if not.
- from xtuner.utils import IMAGE_TOKEN_INDEX # type: ignore
-except Exception:
- IMAGE_TOKEN_INDEX = None # we'll error if not found at runtime when needed
-
-# =======================
-# Utilities (STE ops)
-# =======================
-
-def _gumbel_noise_like(x: torch.Tensor) -> torch.Tensor:
- u = torch.rand_like(x)
- return -torch.log(-torch.log(u + 1e-9) + 1e-9)
-
-
-def ste_gumbel_sigmoid(logits: torch.Tensor, tau: float) -> torch.Tensor:
- y = (logits + _gumbel_noise_like(logits)) / max(tau, 1e-5)
- y = torch.sigmoid(y)
- y_hard = (y > 0.5).float()
- return y_hard + (y - y.detach())
-
-
-def ste_topk(scores: torch.Tensor, k: int) -> torch.Tensor:
- if k <= 0:
- return torch.zeros_like(scores)
- k = min(k, scores.shape[-1])
- topk = torch.topk(scores, k=k, dim=-1).indices
- hard = torch.zeros_like(scores)
- hard.scatter_(-1, topk, 1.0)
- soft = scores / (scores.abs().sum(-1, keepdim=True) + 1e-6)
- return hard + (soft - soft.detach())
-
-
-# =======================
-# Predictor blocks (from user’s custom transformer)
-# =======================
-import collections.abc
-from itertools import repeat
-from functools import partial
-from torch.jit import Final
-
-_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
-_USE_FUSED_ATTN = 1
-_EXPORTABLE = False
-
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
- return tuple(x)
- return tuple(repeat(x, n))
- return parse
-
-to_2tuple = _ntuple(2)
-
-def use_fused_attn(experimental: bool = False) -> bool:
- if not _HAS_FUSED_ATTN or _EXPORTABLE:
- return False
- if experimental:
- return _USE_FUSED_ATTN > 1
- return _USE_FUSED_ATTN > 0
-
-def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1)
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and scale_by_keep:
- random_tensor.div_(keep_prob)
- return x * random_tensor
-
-class DropPath(nn.Module):
- def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
- super().__init__()
- self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0, use_conv=False):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- bias = to_2tuple(bias)
- drop_probs = to_2tuple(drop)
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
- self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
- self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
- self.drop2 = nn.Dropout(drop_probs[1])
- def forward(self, x):
- x = self.fc1(x); x = self.act(x); x = self.drop1(x); x = self.norm(x); x = self.fc2(x); x = self.drop2(x); return x
-
-class Attention(nn.Module):
- fused_attn: Final[bool]
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, norm_layer: nn.Module = nn.LayerNorm) -> None:
- super().__init__()
- assert dim % num_heads == 0
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.fused_attn = use_fused_attn()
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
- q, k, v = qkv.unbind(0)
- q, k = self.q_norm(q), self.k_norm(k)
- if self.fused_attn:
- x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
- else:
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-class CrossAttention(nn.Module):
- fused_attn: Final[bool]
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, norm_layer: nn.Module = nn.LayerNorm) -> None:
- super().__init__()
- assert dim % num_heads == 0
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.fused_attn = use_fused_attn()
- self.q = nn.Linear(dim, dim, bias=qkv_bias)
- self.k = nn.Linear(dim, dim, bias=qkv_bias)
- self.v = nn.Linear(dim, dim, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- def forward(self, qkv: Tuple[torch.Tensor]) -> torch.Tensor:
- assert len(qkv) == 3
- B_q, N_q, C_q = qkv[0].shape
- B_k, N_k, C_k = qkv[1].shape
- B_v, N_v, C_v = qkv[2].shape
- q = self.q(qkv[0]).reshape(B_q, N_q, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
- k = self.k(qkv[1]).reshape(B_k, N_k, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
- v = self.v(qkv[2]).reshape(B_v, N_v, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
- q, k = self.q_norm(q), self.k_norm(k)
- if self.fused_attn:
- x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
- else:
- q = q * self.scale
- attn = q @ k.transpose(-2, -1)
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
- x = attn @ v
- x = x.transpose(1, 2).reshape(B_q, N_q, C_v)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-class LayerScale(nn.Module):
- def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False) -> None:
- super().__init__()
- self.inplace = inplace
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
-
-class SelfTransformerEncoderBlock(nn.Module):
- def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, proj_drop: float = 0.0, attn_drop: float = 0.0, init_values: Optional[float] = None, drop_path: float = 0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp) -> None:
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer)
- self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.norm2 = norm_layer(dim)
- self.mlp = mlp_layer(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
- self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
- x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
- return x
-
-class CrossTransformerEncoderBlock(nn.Module):
- def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, proj_drop: float = 0.0, attn_drop: float = 0.0, init_values: Optional[float] = None, drop_path: float = 0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp) -> None:
- super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer)
- self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- self.norm2 = norm_layer(dim)
- self.mlp = mlp_layer(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
- self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
- def forward(self, qkv: Tuple[torch.Tensor]) -> torch.Tensor:
- assert len(qkv) == 3
- x = qkv[0] + self.drop_path1(self.ls1(self.attn((self.norm1(qkv[0]), self.norm1(qkv[1]), self.norm1(qkv[2])))))
- x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
- return (x, qkv[1], qkv[2])
-
-# ==== NEW: dynamic-llava style predictors ====
-class VisionPredictor(nn.Module):
- def __init__(self, input_dim=4096, d_model=512, nhead=8, dim_feedforward=2048, num_layers=2):
- super().__init__()
- self.input_dim = input_dim
- self.d_model = d_model
- self.nhead = nhead
- self.dim_feedforward = dim_feedforward
- self.num_layers = num_layers
- self.down_mlp = nn.Sequential(nn.LayerNorm(self.input_dim), nn.Linear(self.input_dim, self.d_model), nn.GELU())
- self.transformer = nn.Sequential(*[
- SelfTransformerEncoderBlock(dim=self.d_model, num_heads=self.nhead, mlp_ratio=self.dim_feedforward / self.d_model)
- for _ in range(self.num_layers)
- ])
- self.output_mlp = nn.Sequential(
- nn.Linear(self.d_model, self.d_model // 2), nn.GELU(),
- nn.Linear(self.d_model // 2, self.d_model // 4), nn.GELU(),
- nn.Linear(self.d_model // 4, 2),
- )
- def forward(self, x, image_policy): # x: (B,N,C), image_policy: (B,N,1)
- new_image_x = self.down_mlp(x)
- new_x = self.transformer(new_image_x * image_policy)
- B, N, C = new_x.size()
- local_x = new_x[:, :, : C // 2]
- global_x = (new_x[:, :, C // 2 :] * image_policy).sum(dim=1, keepdim=True) / torch.sum(image_policy, dim=1, keepdim=True).clamp_min(1.0)
- new_x = torch.cat([local_x, global_x.expand(B, N, C // 2)], dim=-1)
- predict = self.output_mlp(new_x) # (B,N,2)
- return predict
-
-class TextPredictor(nn.Module):
- def __init__(self, input_dim=4096, d_model=512, nhead=8, dim_feedforward=2048, num_layers=2):
- super().__init__()
- self.input_dim = input_dim
- self.d_model = d_model
- self.output_mlp = nn.Sequential(
- nn.LayerNorm(self.input_dim), nn.Linear(self.input_dim, self.d_model), nn.GELU(),
- nn.Linear(self.d_model, self.d_model // 2), nn.GELU(),
- nn.Linear(self.d_model // 2, self.d_model // 4), nn.GELU(),
- nn.Linear(self.d_model // 4, 2),
- )
- def forward(self, x): # (B,T,C) -> (B,T,2)
- return self.output_mlp(x)
-
-# =======================
-# Config
-# =======================
-
-@dataclass
-class DynamicCfg:
- # General
- mode: str = "mask" # "mask" | "gather"
- use_text_policy: bool = True
- use_vision_policy: bool = True
-
- # Predictor arch
- predictor: str = "transformer" # "mlp" | "transformer"
- vision_hidden: int = 1024 # for mlp
- text_hidden: int = 1024 # for mlp
- vision_depth: int = 1
- text_depth: int = 1
- num_heads: int = 8
- mlp_ratio: float = 4.0
- dropout: float = 0.0
-
- # Temperatures
- vision_tau: float = 1.0
- text_tau: float = 1.0
-
- # Target keep ratios
- target_keep_ratio_image: float = 0.5
- target_keep_ratio_instruct: float = 0.5
- target_keep_ratio_answer: float = 0.8
-
- # Loss weights
- w_img_kl: float = 0.2
- w_instr_kl: float = 0.2
- w_ans_kl: float = 0.2
- w_entropy: float = 0.01
-
- # Gather controls
- gather_image: bool = True
- gather_text: bool = False # advanced; keep False unless you patch KV/cache
- min_tokens_per_region: int = 8
-
- # Debug
- debug_log_spans: bool = False
-
-
-# =======================
-# Integrations
-# =======================
-
-def enable_lora_if_needed(model) -> None:
- has_4bit = getattr(getattr(model, "llm", None), "is_loaded_in_4bit", False)
- if has_4bit and not getattr(model, "use_llm_lora", False):
- print("[dynamic-xtuner] WARNING: LLM is 4-bit but LoRA is not enabled.")
-
-
-def _build_policies(model, cfg: DynamicCfg):
- H = int(model.llm.config.hidden_size)
- device = next(model.parameters()).device
- dtype = model.llm.dtype
-
- if getattr(cfg, 'predictor', 'dynamic_llava') == 'dynamic_llava':
- v_d_model = getattr(cfg, 'vision_d_model', 512)
- v_nhead = getattr(cfg, 'vision_nhead', 8)
- v_ffn = getattr(cfg, 'vision_dim_feedforward', 2048)
- v_layers = getattr(cfg, 'vision_layers', 2)
- t_d_model = getattr(cfg, 'text_d_model', 512)
- t_layers = getattr(cfg, 'text_layers', 2)
- vision = VisionPredictor(input_dim=H, d_model=v_d_model, nhead=v_nhead, dim_feedforward=v_ffn, num_layers=v_layers)
- text = TextPredictor(input_dim=H, d_model=t_d_model, num_layers=t_layers)
- else:
- # fallback (old MLP / transformer)
- class MLPPolicy(nn.Module):
- def __init__(self, in_features, hidden):
- super().__init__()
- self.net = nn.Sequential(nn.LayerNorm(in_features), nn.Linear(in_features, hidden), nn.GELU(), nn.Linear(hidden, 1))
- def forward(self, x):
- return self.net(x).squeeze(-1)
- vision = MLPPolicy(H, getattr(cfg, 'vision_hidden', 1024))
- text = MLPPolicy(H, getattr(cfg, 'text_hidden', 1024))
-
- model.vision_policy = vision.to(device=device, dtype=dtype) if cfg.use_vision_policy else None
- model.text_policy = text.to(device=device, dtype=dtype) if cfg.use_text_policy else None
-
-
-def integrate_dynamic_llava(model, tokenizer=None, cfg: Optional[DynamicCfg] = None):
- if cfg is None:
- cfg = DynamicCfg()
- enable_lora_if_needed(model)
- _build_policies(model, cfg)
- model.dynamic_cfg = cfg
-
- # Monkey-patch methods
- model.dynamic_prepare = types.MethodType(dynamic_prepare_inputs_labels, model)
- model._orig_compute_loss = model.compute_loss
- model.compute_loss = types.MethodType(_compute_loss_dynamic, model)
-
- # expose forward patch too (tensor/infer)
- if hasattr(model, "_forward"):
- model._orig_forward_tensor = model._forward
- model._forward = types.MethodType(_forward_dynamic, model)
-
- print("[dynamic-xtuner] Integrated dynamic-llava: prepare()+policies ready (mask-mode by default).")
-
-
-# =======================
-# Span helpers
-# =======================
-
-@torch.no_grad()
-def _infer_spans_from_labels(labels: torch.Tensor, image_token_mask: torch.Tensor) -> Dict[str, Tuple[int,int]]:
- assert labels.dim() == 2 and labels.size(0) == 1
- b = 0
- S = labels.size(1)
- img_pos = torch.nonzero(image_token_mask[b], as_tuple=False).flatten()
- image_span = (int(img_pos.min().item()), int(img_pos.max().item()) + 1) if img_pos.numel() > 0 else (0, 0)
-
- ans_mask = (labels[b] != IGNORE_INDEX).bool()
- if ans_mask.any():
- idx = torch.nonzero(ans_mask, as_tuple=False).flatten()
- answer_span = (int(idx.min().item()), int(idx.max().item()) + 1)
- else:
- answer_span = (S, S)
-
- instr_mask = (~ans_mask) & (~image_token_mask[b])
- if answer_span[0] > 0:
- instr_idx = torch.nonzero(instr_mask[:answer_span[0]], as_tuple=False).flatten()
- instruction_span = (int(instr_idx.min().item()), int(instr_idx.max().item()) + 1) if instr_idx.numel() else (0, 0)
- else:
- instruction_span = (0, 0)
-
- return {"image": image_span, "last_instruct": instruction_span, "answer": answer_span}
-
-
-# =======================
-# Dynamic multimodal prepare (+ optional gather)
-# =======================
-
-def dynamic_prepare_inputs_labels(self, data: Dict) -> Tuple[Dict, List[torch.Tensor]]:
- """Replacement for XTuner's prepare that also computes region spans, an image mask,
- and (optionally) gathers tokens before feeding the LLM.
-
- Expects `data` to contain: input_ids, attention_mask (optional), position_ids (optional),
- labels (optional), pixel_values (projected, [Nimg, Limg, H]).
-
- Returns: (prepared_data, prep_losses)
- Side effect: sets `self._dyn_ctx` with image_token_mask and region_spans for layer-wise policies.
- """
- cfg: DynamicCfg = self.dynamic_cfg
- llm = self.llm
-
- input_ids = data['input_ids']
- attention_mask = data.get('attention_mask', None)
- position_ids = data.get('position_ids', None)
- labels = data.get('labels', None)
- pixel_values = data.get('pixel_values', None) # [Nimg, Limg, H]
-
- # If no images, fall back to vanilla
- if pixel_values is None:
- return {k: data.get(k, None) for k in ['input_ids','position_ids','attention_mask','past_key_values','inputs_embeds','labels']}, []
-
- if IMAGE_TOKEN_INDEX is None:
- raise RuntimeError("IMAGE_TOKEN_INDEX not available; cannot locate placeholders.")
-
- _labels = labels
- _position_ids = position_ids
- _attention_mask = attention_mask
-
- if attention_mask is None:
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
- else:
- attention_mask = attention_mask.bool()
- if position_ids is None:
- position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
- if labels is None:
- labels = torch.full_like(input_ids, IGNORE_INDEX)
-
- # remove padding by mask
- input_ids_list = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
- labels_list = [lab[mask] for lab, mask in zip(labels, attention_mask)]
-
- new_inputs_embeds_unpad: List[torch.Tensor] = []
- new_labels_unpad: List[torch.Tensor] = []
- image_token_mask_unpad: List[torch.Tensor] = []
- prep_losses: List[torch.Tensor] = []
-
- cur_image_idx = 0
- tok_embed = llm.get_input_embeddings()
-
- for b_idx, cur_input_ids in enumerate(input_ids_list):
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
- cur_labels = labels_list[b_idx]
-
- if num_images == 0:
- cur_inputs_embeds_1 = tok_embed(cur_input_ids)
- cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, pixel_values[0:0]], dim=0)
- new_inputs_embeds_unpad.append(cur_inputs_embeds)
- new_labels_unpad.append(cur_labels)
- image_token_mask_unpad.append(torch.zeros(cur_inputs_embeds.size(0), dtype=torch.bool, device=cur_inputs_embeds.device))
- continue
-
- image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
- cur_input_ids_noim = []
- cur_labels_noim = []
- for i in range(len(image_token_indices) - 1):
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
- cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
-
- split_sizes = [x.shape[0] for x in cur_labels_noim]
- cur_inputs_embeds_noim = torch.split(tok_embed(torch.cat(cur_input_ids_noim)), split_sizes, dim=0)
-
- cur_new_inputs_embeds = []
- cur_new_labels = []
- cur_img_mask = []
-
- for i in range(num_images + 1):
- # text segment
- seg = cur_inputs_embeds_noim[i]
- lab = cur_labels_noim[i]
- cur_new_inputs_embeds.append(seg)
- cur_new_labels.append(lab)
- cur_img_mask.append(torch.zeros(seg.shape[0], dtype=torch.bool, device=seg.device))
-
- if i < num_images:
- # image block (projected tokens)
- img_emb = pixel_values[cur_image_idx] # (Limg, H)
- cur_image_idx += 1
-
- # (Optional) pre-LLM gather on image tokens
- if cfg.use_vision_policy and cfg.mode == 'gather' and cfg.gather_image:
- # VisionPredictor expects (B,N,C) and image_policy (B,N,1)
- logits2 = self.vision_policy(img_emb.unsqueeze(0).to(llm.dtype), torch.ones(1, img_emb.size(0), 1, device=img_emb.device, dtype=llm.dtype)) # (1,L,2)
- probs = torch.softmax(logits2, dim=-1)[..., 1] # keep prob (1,L)
- keep_ratio = cfg.target_keep_ratio_image
- k = max(cfg.min_tokens_per_region, int(math.ceil(keep_ratio * img_emb.size(0))))
- keep = ste_topk(probs, k=k) if self.training else (probs > 0.5).float()
- keep_idx = torch.nonzero(keep[0] > 0.5, as_tuple=False).flatten()
- if keep_idx.numel() == 0:
- keep_idx = torch.arange(0, min(cfg.min_tokens_per_region, img_emb.size(0)), device=img_emb.device)
- img_kept = img_emb[keep_idx]
- # losses for ratio + entropy
- avg_keep = keep.mean()
- target = torch.tensor(keep_ratio, device=avg_keep.device, dtype=avg_keep.dtype)
- p = probs.clamp(1e-6, 1-1e-6)
- entropy = - (p * p.log() + (1-p) * (1-p).log())
- prep_losses.append((avg_keep - target).pow(2) * cfg.w_img_kl)
- prep_losses.append(entropy.mean() * cfg.w_entropy)
- else:
- img_kept = img_emb
- else:
- img_kept = img_emb
-
- cur_new_inputs_embeds.append(img_kept)
- cur_new_labels.append(torch.full((img_kept.shape[0],), IGNORE_INDEX, device=lab.device, dtype=lab.dtype))
- cur_img_mask.append(torch.ones(img_kept.shape[0], dtype=torch.bool, device=lab.device))
-
- cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds, dim=0)
- cur_new_labels = torch.cat(cur_new_labels, dim=0)
- cur_img_mask = torch.cat(cur_img_mask, dim=0)
-
- new_inputs_embeds_unpad.append(cur_new_inputs_embeds)
- new_labels_unpad.append(cur_new_labels)
- image_token_mask_unpad.append(cur_img_mask)
-
- # Pad back to batch max
- max_len = max(x.shape[0] for x in new_inputs_embeds_unpad)
- B = len(new_inputs_embeds_unpad)
-
- new_inputs_embeds_padded = []
- new_labels_padded = torch.full((B, max_len), IGNORE_INDEX, dtype=new_labels_unpad[0].dtype, device=new_labels_unpad[0].device)
- attn_mask = torch.zeros((B, max_len), dtype=torch.bool, device=attention_mask.device)
- pos_ids = torch.zeros((B, max_len), dtype=position_ids.dtype, device=position_ids.device)
- image_token_mask = torch.zeros((B, max_len), dtype=torch.bool, device=attention_mask.device)
-
- for i, (emb, lab, imask) in enumerate(zip(new_inputs_embeds_unpad, new_labels_unpad, image_token_mask_unpad)):
- L = emb.shape[0]
- pad = torch.zeros((max_len - L, emb.shape[1]), dtype=emb.dtype, device=emb.device)
- new_inputs_embeds_padded.append(torch.cat((emb, pad), dim=0))
- new_labels_padded[i, :L] = lab
- attn_mask[i, :L] = True
- pos_ids[i, :L] = torch.arange(0, L, dtype=pos_ids.dtype, device=pos_ids.device)
- image_token_mask[i, :L] = imask
-
- new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
-
- # Infer spans; store ctx for layer-wise policies
- spans = _infer_spans_from_labels(new_labels_padded, image_token_mask)
- if cfg.debug_log_spans:
- print(f"[dynamic-xtuner] spans: {spans}")
- # Keep per-forward ctx
- self._dyn_ctx = dict(image_token_mask=image_token_mask, region_spans=spans, loss_terms=[])
-
- # Return with XTuner-compatible keys
- out = {
- 'input_ids': None,
- 'position_ids': None if _position_ids is None else pos_ids,
- 'attention_mask': None if _attention_mask is None else attn_mask.to(dtype=_attention_mask.dtype),
- 'past_key_values': data.get('past_key_values', None),
- 'inputs_embeds': new_inputs_embeds,
- 'labels': None if _labels is None else new_labels_padded,
- # extra metadata (won't be consumed by HF)
- 'image_token_mask': image_token_mask,
- 'region_spans': spans,
- }
- return out, prep_losses
-
-
-# =======================
-# Compute loss / forward patches
-# =======================
-
-def _compute_loss_dynamic(self, data, data_samples=None):
- # assume LongNet + projector already ran in self.forward before calling us
- if 'pixel_values' in data and isinstance(data['pixel_values'], torch.Tensor):
- data['__img_embeds__'] = data['pixel_values']
-
- prepared, prep_losses = self.dynamic_prepare(data)
-
- outputs = self.llm(**prepared)
- total = outputs.loss
-
- for v in getattr(self, '_dyn_ctx', {}).get('loss_terms', []):
- total = total + v
- if hasattr(self, '_dyn_ctx'):
- self._dyn_ctx['loss_terms'] = []
-
- for v in prep_losses:
- total = total + v
-
- return {'loss': total}
-
-
-def _forward_dynamic(self, data, data_samples=None):
- if 'pixel_values' in data and isinstance(data['pixel_values'], torch.Tensor):
- data['__img_embeds__'] = data['pixel_values']
- prepared, _ = self.dynamic_prepare(data)
- return self.llm(**prepared)
-
-
-# =======================
-# Layer-wise masking (mask-only per layer)
-# =======================
-import transformers
-from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
-
-def _ensure_dyn_ctx(self):
- if not hasattr(self, "_dyn_ctx") or not isinstance(self._dyn_ctx, dict):
- self._dyn_ctx = {}
- if "loss_terms" not in self._dyn_ctx:
- self._dyn_ctx["loss_terms"] = []
-
-def _merge_additive_mask(base_mask, extra_add):
- if extra_add is None:
- return base_mask
- if base_mask is None:
- return extra_add
- if base_mask.dtype == torch.bool:
- base_add = (1.0 - base_mask.float()) * float('-inf')
- else:
- base_add = base_mask
- return base_add + extra_add
-
-
-def enable_layerwise_policies(model, cfg: DynamicCfg, every_n_layers: int = 1, target_layers: Optional[List[int]] = None):
- assert hasattr(model, 'text_policy') or hasattr(model, 'vision_policy'), "Call integrate_dynamic_llava() first."
- model_ref = model
- def wrap_layer(layer: Qwen2DecoderLayer, layer_idx: int):
- if getattr(layer, '_dyn_wrapped', False):
- return
- layer._dyn_wrapped = True
- orig_forward = layer.forward
- def forward_with_dynamic(self_layer, hidden_states, **kwargs):
- model = model_ref
- cfg = model.dynamic_cfg
- _ensure_dyn_ctx(model)
- base_mask = kwargs.get('attention_mask', None)
- ctx = getattr(model, '_dyn_ctx', {})
- spans = ctx.get('region_spans')
- img_mask = ctx.get('image_token_mask')
- extra = None
- if spans is not None:
- B, S, H = hidden_states.shape
- assert B == 1, "Layer-wise patch assumes batch_size==1"
- extra = torch.zeros((B, S), dtype=hidden_states.dtype, device=hidden_states.device)
- if cfg.use_vision_policy and img_mask is not None and img_mask.any():
- s, e = spans["image"]
- feats = hidden_states[:, s:e, :]
- logits = model.vision_policy(feats.to(model.llm.dtype))
- keep = ste_gumbel_sigmoid(logits, tau=cfg.vision_tau) if model.training else (torch.sigmoid(logits) > 0.5).float()
- drop = (keep < 0.5).squeeze(0).bool()
- extra[0, s:e][drop] = float('-inf')
- avg_keep = keep.mean(); target = torch.tensor(cfg.target_keep_ratio_image, device=avg_keep.device, dtype=avg_keep.dtype)
- probs = torch.sigmoid(logits)
- entropy = - (probs * (probs.clamp_min(1e-6).log()) + (1-probs) * ((1-probs).clamp_min(1e-6).log()))
- model._dyn_ctx['loss_terms'].append((avg_keep - target).pow(2) * cfg.w_img_kl)
- model._dyn_ctx['loss_terms'].append(entropy.mean() * cfg.w_entropy)
- if cfg.use_text_policy:
- for name, (w_kl, target_ratio) in {
- 'last_instruct': (cfg.w_instr_kl, cfg.target_keep_ratio_instruct),
- 'answer': (cfg.w_ans_kl, cfg.target_keep_ratio_answer),
- }.items():
- s, e = spans[name]
- if e > s:
- feats = hidden_states[:, s:e, :]
- logits = model.text_policy(feats.to(model.llm.dtype))
- keep = ste_gumbel_sigmoid(logits, tau=cfg.text_tau) if model.training else (torch.sigmoid(logits) > 0.5).float()
- drop = (keep < 0.5).squeeze(0).bool()
- extra[0, s:e][drop] = float('-inf')
- avg_keep = keep.mean(); target = torch.tensor(target_ratio, device=avg_keep.device, dtype=avg_keep.dtype)
- probs = torch.sigmoid(logits)
- entropy = - (probs * (probs.clamp_min(1e-6).log()) + (1-probs) * ((1-probs).clamp_min(1e-6).log()))
- model._dyn_ctx['loss_terms'].append((avg_keep - target).pow(2) * w_kl)
- model._dyn_ctx['loss_terms'].append(entropy.mean() * cfg.w_entropy)
- if extra is not None:
- kwargs['attention_mask'] = _merge_additive_mask(base_mask, extra)
- return orig_forward(hidden_states, **kwargs)
- layer.forward = types.MethodType(forward_with_dynamic, layer)
-
- all_layers = [m for m in model.llm.modules() if isinstance(m, Qwen2DecoderLayer)]
- if target_layers is None:
- target_layers = list(range(len(all_layers)))[::every_n_layers]
- for idx in target_layers:
- wrap_layer(all_layers[idx], idx)
- print(f"[dynamic-xtuner] Layer-wise policies enabled on layers {target_layers}.")
-
-
-# =======================
-# state_dict patch (save policies)
-# =======================
-from collections import OrderedDict
-
-def patch_state_dict_to_save_policies(model: nn.Module) -> None:
- if getattr(model, "_state_dict_policies_patched", False):
- return
- model._orig_state_dict_save = model.state_dict
- def state_dict_with_policies(self, *args, **kwargs):
- base = self._orig_state_dict_save(*args, **kwargs)
- out = OrderedDict(base)
- vp = getattr(self, "vision_policy", None)
- if isinstance(vp, nn.Module):
- for k, v in vp.state_dict(*args, **kwargs).items():
- out[f"vision_policy.{k}"] = v
- tp = getattr(self, "text_policy", None)
- if isinstance(tp, nn.Module):
- for k, v in tp.state_dict(*args, **kwargs).items():
- out[f"text_policy.{k}"] = v
- return out
- model.state_dict = types.MethodType(state_dict_with_policies, model)
- model._state_dict_policies_patched = True
-
-
-# =======================
-# Optional: attention patch placeholder (not needed for mask/gather here)
-# =======================
-
-def patch_qwen2_attention_for_dynamic(model) -> None:
- print("[dynamic-xtuner] attention patch placeholder (not required for current modes)")
-
-
-@HOOKS.register_module()
-class DynamicLlavaPatchHook(Hook):
- def __init__(self, dyn_cfg=None, layers_every=1):
- self.dyn_cfg = dyn_cfg or {}
- self.layers_every = layers_every
-
- def before_train(self, runner):
- from xtuner.utils.dynamic_llava_dispatch import (
- integrate_dynamic_llava, DynamicCfg, enable_layerwise_policies,
- patch_state_dict_to_save_policies
- )
- cfg = DynamicCfg(**self.dyn_cfg)
- integrate_dynamic_llava(runner.model, None, cfg)
- enable_layerwise_policies(runner.model, cfg, every_n_layers=self.layers_every)
- patch_state_dict_to_save_policies(runner.model)
- runner.logger.info('[dynamic-xtuner] dynamic prepare + layer-wise policies + policy saving enabled')
\ No newline at end of file
diff --git a/code/xtuner/utils/fileio.py b/code/xtuner/utils/fileio.py
deleted file mode 100644
index 922146e584313f35b5cdcd76b3908ed0e4f7ce11..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/fileio.py
+++ /dev/null
@@ -1,345 +0,0 @@
-import io
-from contextlib import contextmanager
-
-import mmengine.fileio as fileio
-from mmengine.fileio import LocalBackend, PetrelBackend, get_file_backend
-
-
-def patch_func(module, fn_name_to_wrap):
- backup = getattr(patch_func, '_backup', [])
- fn_to_wrap = getattr(module, fn_name_to_wrap)
-
- def wrap(fn_new):
- setattr(module, fn_name_to_wrap, fn_new)
- backup.append((module, fn_name_to_wrap, fn_to_wrap))
- setattr(fn_new, '_fallback', fn_to_wrap)
- setattr(patch_func, '_backup', backup)
- return fn_new
-
- return wrap
-
-
-@contextmanager
-def patch_fileio(global_vars=None):
- if getattr(patch_fileio, '_patched', False):
- # Only patch once, avoid error caused by patch nestly.
- yield
- return
- import builtins
-
- @patch_func(builtins, 'open')
- def open(file, mode='r', *args, **kwargs):
- backend = get_file_backend(file)
- if isinstance(backend, LocalBackend):
- return open._fallback(file, mode, *args, **kwargs)
- if 'b' in mode:
- return io.BytesIO(backend.get(file, *args, **kwargs))
- else:
- return io.StringIO(backend.get_text(file, *args, **kwargs))
-
- if global_vars is not None and 'open' in global_vars:
- bak_open = global_vars['open']
- global_vars['open'] = builtins.open
-
- import os
-
- @patch_func(os.path, 'join')
- def join(a, *paths):
- backend = get_file_backend(
- a.decode('utf-8') if isinstance(a, bytes) else a)
- if isinstance(backend, LocalBackend):
- return join._fallback(a, *paths)
- paths = [item.lstrip('./') for item in paths if len(item) > 0]
- return backend.join_path(a, *paths)
-
- @patch_func(os.path, 'isdir')
- def isdir(path):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return isdir._fallback(path)
-
- return backend.isdir(path)
-
- @patch_func(os.path, 'isfile')
- def isfile(path):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return isfile._fallback(path)
-
- return backend.isfile(path)
-
- @patch_func(os.path, 'exists')
- def exists(path):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return exists._fallback(path)
- return backend.exists(path)
-
- @patch_func(os, 'mkdir')
- def mkdir(path, *args, **kwargs):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return mkdir._fallback(path, *args, **kwargs)
-
- @patch_func(os, 'makedirs')
- def makedirs(path, *args, **kwargs):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return makedirs._fallback(path, *args, **kwargs)
-
- @patch_func(os, 'listdir')
- def listdir(path):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return listdir._fallback(path)
- return backend.list_dir_or_file(path)
-
- @patch_func(os, 'chmod')
- def chmod(path, *args, **kwargs):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return chmod._fallback(path, *args, **kwargs)
-
- @patch_func(os, 'stat')
- def stat(path, *args, **kwargs):
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return stat._fallback(path, *args, **kwargs)
-
- import glob as glob_pkg
-
- @patch_func(glob_pkg, 'glob')
- def glob(pathname, *, recursive=False):
- backend = get_file_backend(pathname)
- if isinstance(backend, LocalBackend):
- return glob._fallback(pathname, recursive=recursive)
-
- if pathname.endswith('*_optim_states.pt'):
- import os
- pathname = os.path.split(pathname)[0]
- files = backend.list_dir_or_file(pathname, recursive=recursive)
- files = [
- os.path.join(pathname, f) for f in files
- if f.endswith('_optim_states.pt')
- ]
- elif pathname.endswith('*_model_states.pt'):
- import os
- pathname = os.path.split(pathname)[0]
- files = backend.list_dir_or_file(pathname, recursive=recursive)
- files = [
- os.path.join(pathname, f) for f in files
- if f.endswith('_model_states.pt')
- ]
- elif '*' in pathname:
- raise NotImplementedError
- else:
- files = backend.list_dir_or_file(pathname, recursive=recursive)
-
- return files
-
- import filecmp
-
- @patch_func(filecmp, 'cmp')
- def cmp(f1, f2, *args, **kwargs):
- with fileio.get_local_path(f1) as f1, fileio.get_local_path(f2) as f2:
- return cmp._fallback(f1, f2, *args, **kwargs)
-
- import shutil
-
- @patch_func(shutil, 'copy')
- def copy(src, dst, **kwargs):
- from pathlib import Path
-
- if isinstance(src, Path):
- src = str(src).replace(':/', '://')
- if isinstance(dst, Path):
- dst = str(dst).replace(':/', '://')
-
- src_backend = get_file_backend(src)
- dst_backend = get_file_backend(dst)
-
- if isinstance(src_backend, LocalBackend) and isinstance(
- dst_backend, LocalBackend):
- return copy._fallback(src, dst, **kwargs)
- elif isinstance(src_backend, LocalBackend) and isinstance(
- dst_backend, PetrelBackend):
- return dst_backend.copyfile_from_local(str(src), str(dst))
- elif isinstance(src_backend, PetrelBackend) and isinstance(
- dst_backend, LocalBackend):
- return src_backend.copyfile_to_local(str(src), str(dst))
-
- import torch
-
- @patch_func(torch, 'load')
- def load(f, *args, **kwargs):
- if isinstance(f, str):
- f = io.BytesIO(fileio.get(f))
- return load._fallback(f, *args, **kwargs)
-
- @patch_func(torch, 'save')
- def save(obj, f, *args, **kwargs):
- backend = get_file_backend(f)
- if isinstance(backend, LocalBackend):
- return save._fallback(obj, f, *args, **kwargs)
-
- with io.BytesIO() as buffer:
- save._fallback(obj, buffer, *args, **kwargs)
- buffer.seek(0)
- backend.put(buffer, f)
-
- # from tempfile import TemporaryDirectory
- # import os
- # with TemporaryDirectory(dir='/dev/shm') as tmpdir:
- # suffix = os.path.split(f)[-1]
- # tmppath = os.path.join._fallback(tmpdir, suffix)
- # from mmengine import print_log
- # print_log('write to tmp dir', logger='current')
- # save._fallback(obj, tmppath, *args, **kwargs)
- # print_log('write to ceph', logger='current')
-
- # with open(tmppath, 'rb') as buffer:
- # backend.put(buffer, f)
-
- from sentencepiece import SentencePieceProcessor
-
- @patch_func(SentencePieceProcessor, 'LoadFromFile')
- def LoadFromFile(cls, path):
- if path:
- backend = get_file_backend(path)
- if isinstance(backend, LocalBackend):
- return LoadFromFile._fallback(cls, path)
- from tempfile import TemporaryDirectory
- with TemporaryDirectory() as tmpdir:
- local_path = backend.copyfile_to_local(path, tmpdir)
- loaded_file = LoadFromFile._fallback(cls, local_path)
- return loaded_file
- else:
- return LoadFromFile._fallback(cls, path)
-
- try:
- setattr(patch_fileio, '_patched', True)
- yield
- finally:
- for patched_fn in patch_func._backup:
- (module, fn_name_to_wrap, fn_to_wrap) = patched_fn
- setattr(module, fn_name_to_wrap, fn_to_wrap)
- if global_vars is not None and 'open' in global_vars:
- global_vars['open'] = bak_open
- setattr(patch_fileio, '_patched', False)
-
-
-def patch_hf_auto_from_pretrained(petrel_hub):
- if hasattr(patch_hf_auto_from_pretrained, '_patched'):
- return
-
- from peft import PeftModel
- from transformers import (AutoConfig, AutoFeatureExtractor,
- AutoImageProcessor, AutoModelForCausalLM,
- AutoProcessor, AutoTokenizer,
- ImageProcessingMixin, PreTrainedModel,
- PreTrainedTokenizerBase, ProcessorMixin)
- from transformers.models.auto.auto_factory import _BaseAutoModelClass
-
- target_cls = list(_BaseAutoModelClass.__subclasses__())
- target_cls.extend([AutoModelForCausalLM] +
- AutoModelForCausalLM.__subclasses__())
- target_cls.extend([AutoConfig] + AutoConfig.__subclasses__())
- target_cls.extend([AutoTokenizer] + AutoTokenizer.__subclasses__())
- target_cls.extend([AutoImageProcessor] +
- AutoImageProcessor.__subclasses__())
- target_cls.extend([AutoFeatureExtractor] +
- AutoFeatureExtractor.__subclasses__())
- target_cls.extend([AutoProcessor] + AutoProcessor.__subclasses__())
- target_cls.extend([PreTrainedTokenizerBase] +
- PreTrainedTokenizerBase.__subclasses__())
- target_cls.extend([ImageProcessingMixin] +
- ImageProcessingMixin.__subclasses__())
- target_cls.extend([PreTrainedModel] + PreTrainedModel.__subclasses__())
- target_cls.extend([ProcessorMixin] + ProcessorMixin.__subclasses__())
- target_cls.extend([PeftModel] + PeftModel.__subclasses__())
-
- import os
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
- with patch_fileio():
- model_path = pretrained_model_name_or_path
- model_path = os.path.join(petrel_hub, model_path)
- obj = cls._from_pretrained(model_path, *args, **kwargs)
- return obj
-
- for cls in set(target_cls):
- if not hasattr(cls, '_from_pretrained'):
- cls._from_pretrained = cls.from_pretrained
- cls.from_pretrained = from_pretrained
-
- patch_hf_auto_from_pretrained._patched = True
-
-
-def patch_hf_save_pretrained():
- if hasattr(patch_hf_save_pretrained, '_patched'):
- return
-
- import torch
- from peft import PeftModel
- from transformers import (AutoConfig, AutoTokenizer, PreTrainedModel,
- PreTrainedTokenizerBase)
- from transformers.models.auto.auto_factory import _BaseAutoModelClass
-
- target_cls = []
- target_cls.extend([AutoConfig] + AutoConfig.__subclasses__())
- target_cls.extend([AutoTokenizer] + AutoTokenizer.__subclasses__())
- target_cls.extend([PreTrainedTokenizerBase] +
- PreTrainedTokenizerBase.__subclasses__())
- target_cls.extend([PreTrainedModel] + PreTrainedModel.__subclasses__())
-
- target_cls.extend([_BaseAutoModelClass] +
- _BaseAutoModelClass.__subclasses__())
- target_cls.extend([PeftModel] + PeftModel.__subclasses__())
-
- def _patch_wrap(method):
-
- def wrapped_method(self, *args, **kwargs):
-
- with patch_fileio():
- kwargs['save_function'] = torch.save
- kwargs['safe_serialization'] = False
-
- obj = method(self, *args, **kwargs)
- return obj
-
- return wrapped_method
-
- for cls in set(target_cls):
- if hasattr(cls, 'save_pretrained'):
- cls.save_pretrained = _patch_wrap(cls.save_pretrained)
-
- patch_hf_save_pretrained._patched = True
-
-
-def patch_deepspeed_engine():
- if hasattr(patch_deepspeed_engine, '_patched'):
- return
-
- def _copy_recovery_script(self, save_path):
- import os
- from shutil import copyfile
-
- from deepspeed.utils import zero_to_fp32
- from mmengine import PetrelBackend, get_file_backend
- script = 'zero_to_fp32.py'
-
- src = zero_to_fp32.__file__
- dst = os.path.join(save_path, script)
-
- backend = get_file_backend(save_path)
- if isinstance(backend, PetrelBackend):
- backend.copyfile_from_local(src, dst)
- else:
- copyfile(src, dst)
- self._change_recovery_script_permissions(dst)
-
- from deepspeed.runtime.engine import DeepSpeedEngine
- DeepSpeedEngine._copy_recovery_script = _copy_recovery_script
-
- patch_deepspeed_engine._patched = True
diff --git a/code/xtuner/utils/handle_moe_load_and_save.py b/code/xtuner/utils/handle_moe_load_and_save.py
deleted file mode 100644
index 88a3936a84b8de7311e3a00d7e0661a2a3265736..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/handle_moe_load_and_save.py
+++ /dev/null
@@ -1,232 +0,0 @@
-import json
-import os
-import re
-from collections import OrderedDict
-
-import deepspeed
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from mmengine import print_log
-from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import load_state_dict
-from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
- is_safetensors_available)
-
-SUPPORT_MODELS = (
- 'DeepseekV2ForCausalLM',
- 'MixtralForCausalLM',
-)
-
-ORDER_MAPPING = dict(
- DeepseekV2ForCausalLM=dict(down_proj=0, gate_proj=1, up_proj=2),
- MixtralForCausalLM=dict(down_proj=1, gate_proj=0, up_proj=2),
-)
-
-PARAM_NAME_MAPPING = dict(
- DeepseekV2ForCausalLM=dict(
- gate_proj='gate_proj', up_proj='up_proj', down_proj='down_proj'),
- MixtralForCausalLM=dict(gate_proj='w1', up_proj='w3', down_proj='w2'),
-)
-
-
-def print_on_rank0(info):
- if dist.get_rank() == 0:
- print_log(info, 'current')
-
-
-def get_expert_num_per_shard(model):
- for module in model.modules():
- if hasattr(module, 'expert_in_one_shard'):
- return module.expert_in_one_shard
-
-
-def mix_sort(expert_name):
- components = re.findall(r'(\D+|\d+)', expert_name)
- out = [int(comp) if comp.isdigit() else comp for comp in components]
- return tuple(out)
-
-
-def _get_merged_param_name(origin_param_name, expert_num_per_shard):
- split_name = origin_param_name.split('.experts.')
- expert_idx = re.findall(r'\d+', split_name[1])[0]
- expert_idx = int(expert_idx)
- assert expert_idx % expert_num_per_shard == 0
- shard_idx = expert_idx // expert_num_per_shard
- w1w3 = split_name[0] + f'.experts.{shard_idx}.w1w3'
- w2 = split_name[0] + f'.experts.{shard_idx}.w2'
- return w1w3, w2
-
-
-def _merge_experts_weight(state_dict, expert_num_per_shard, order_mapping):
- experts_name = [key for key in state_dict.keys() if '.experts.' in key]
- experts_name = sorted(experts_name, key=mix_sort)
- linear_num_per_expert = 3
- linear_num_per_shard = expert_num_per_shard * linear_num_per_expert
- expert_shard_num = len(experts_name) // linear_num_per_shard
- for shard_idx in range(expert_shard_num):
- begin, end = shard_idx * linear_num_per_shard, (
- shard_idx + 1) * linear_num_per_shard
- experts_name_cur = experts_name[begin:end]
-
- down_proj_weight = [
- state_dict.pop(key)
- for key in experts_name_cur[order_mapping['down_proj']::3]
- ]
- gate_proj_weight = [
- state_dict.pop(key)
- for key in experts_name_cur[order_mapping['gate_proj']::3]
- ]
- up_proj_weight = [
- state_dict.pop(key)
- for key in experts_name_cur[order_mapping['up_proj']::3]
- ]
- w1 = torch.stack(gate_proj_weight)
- w3 = torch.stack(up_proj_weight)
- w1w3 = torch.cat([w1, w3], dim=1)
- assert w1w3.ndim == 3, w1w3.shape
- w2 = torch.stack(down_proj_weight)
- assert w2.ndim == 3, w2.shape
- merged_key_w1w3, merged_key_w2 = _get_merged_param_name(
- experts_name_cur[0], expert_num_per_shard)
- print_on_rank0(f'merged key {merged_key_w1w3}')
- state_dict[merged_key_w1w3] = w1w3
- print_on_rank0(f'merged key {merged_key_w2}')
- state_dict[merged_key_w2] = w2
-
- return
-
-
-def load_state_dict_into_model(model_to_load, pretrained_model_path):
-
- model_name = type(model_to_load).__name__
- if model_name not in SUPPORT_MODELS:
- raise RuntimeError(
- f'Only models in {SUPPORT_MODELS} may need to load pretrained '
- f'weights via `load_state_dict_into_model`, but got {model_name}.')
- order_mapping = ORDER_MAPPING[model_name]
-
- index_file = os.path.join(pretrained_model_path, WEIGHTS_INDEX_NAME)
- safe_index_file = os.path.join(pretrained_model_path,
- SAFE_WEIGHTS_INDEX_NAME)
- index_present = os.path.isfile(index_file)
- safe_index_present = os.path.isfile(safe_index_file)
- assert index_present or (safe_index_present and is_safetensors_available())
- if safe_index_present and is_safetensors_available():
- load_index = safe_index_file
- else:
- load_index = index_file
- with open(load_index, encoding='utf-8') as f:
- index = json.load(f)
- weight_map = index['weight_map']
- unloaded_shard_files = list(set(weight_map.values()))
- unloaded_shard_files.sort(reverse=True)
-
- expert_num_per_shard = get_expert_num_per_shard(model_to_load)
- error_msgs = []
-
- def load(module: nn.Module, state_dict, unloaded_shard_files, prefix=''):
- params_to_gather = []
- param_names = []
- for name, param in module.named_parameters(
- prefix=prefix[:-1], recurse=False):
- while name not in state_dict:
- assert len(unloaded_shard_files) > 0
- shard_file = unloaded_shard_files.pop()
- shard_file = os.path.join(pretrained_model_path, shard_file)
- print_on_rank0(
- f'{name} not in state_dict, loading {shard_file}')
- new_shard = load_state_dict(shard_file, is_quantized=False)
- state_dict.update(new_shard)
- _merge_experts_weight(state_dict, expert_num_per_shard,
- order_mapping)
- params_to_gather.append(param)
- param_names.append(name)
- if len(params_to_gather) > 0:
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
- if is_deepspeed_zero3_enabled():
- with deepspeed.zero.GatheredParameters(
- params_to_gather, modifier_rank=0):
- if dist.get_rank() == 0:
- module._load_from_state_dict(*args)
- else:
- module._load_from_state_dict(*args)
-
- for name in param_names:
- print_on_rank0(f'state_dict pop {name}')
- state_dict.pop(name)
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, state_dict, unloaded_shard_files,
- prefix + name + '.')
-
- state_dict = OrderedDict()
- load(model_to_load, state_dict, unloaded_shard_files, prefix='')
- print_on_rank0(f'{state_dict.keys()}')
- del state_dict
-
- return error_msgs
-
-
-def _get_origin_param_name(merged_param_name, expert_num_per_shard, is_w1w3,
- param_name_mapping):
- split_name = merged_param_name.split('.experts.')
- shard_idx = re.findall(r'\d+', split_name[1])[0]
- shard_idx = int(shard_idx)
- origin_param_names = [None] * (expert_num_per_shard * (1 + int(is_w1w3)))
- expert_idx_begin = expert_num_per_shard * shard_idx
- for i in range(expert_num_per_shard):
- if is_w1w3:
- gate_proj, up_proj = param_name_mapping[
- 'gate_proj'], param_name_mapping['up_proj']
- gate = split_name[
- 0] + f'.experts.{expert_idx_begin + i}.{gate_proj}.weight'
- up = split_name[
- 0] + f'.experts.{expert_idx_begin + i}.{up_proj}.weight'
- origin_param_names[i * 2] = gate
- origin_param_names[i * 2 + 1] = up
- else:
- down_proj = param_name_mapping['down_proj']
- down = split_name[
- 0] + f'.experts.{expert_idx_begin + i}.{down_proj}.weight'
- origin_param_names[i] = down
- return origin_param_names
-
-
-def _split_param(merged_param, is_w1w3):
- if is_w1w3:
- expert_num, _, hidden_dim = merged_param.shape
- merged_param = merged_param.view(expert_num * 2, -1, hidden_dim)
- return torch.unbind(merged_param, dim=0)
- else:
- # (e, hidden_dim, ffn_dim)
- return torch.unbind(merged_param, dim=0)
-
-
-def get_origin_state_dict(state_dict, model):
-
- model_name = type(model).__name__
- if model_name not in SUPPORT_MODELS:
- raise RuntimeError(
- f'Only models in {SUPPORT_MODELS} may need to convert state_dict '
- f'via `get_origin_state_dict` interface, but got {model_name}.')
- param_name_mapping = PARAM_NAME_MAPPING[model_name]
-
- expert_num_per_shard = get_expert_num_per_shard(model)
- experts_param_name = [
- name for name in state_dict.keys() if '.experts.' in name
- ]
- for expert_param_name in experts_param_name:
- print_on_rank0(f'processing {expert_param_name} ...')
- is_w1w3 = expert_param_name.split('.')[-1] == 'w1w3'
- origin_param_names = _get_origin_param_name(expert_param_name,
- expert_num_per_shard,
- is_w1w3,
- param_name_mapping)
- merged_param = state_dict.pop(expert_param_name)
- origin_params = _split_param(merged_param, is_w1w3)
- assert len(origin_param_names) == len(origin_params)
- for name, param in zip(origin_param_names, origin_params):
- state_dict[name] = param
- return state_dict
diff --git a/code/xtuner/utils/rm_utils.py b/code/xtuner/utils/rm_utils.py
deleted file mode 100644
index 6f72c3b9248a036201406ac85e05de3b30f8d4fa..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/rm_utils.py
+++ /dev/null
@@ -1,305 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import time
-from typing import List, Union
-
-import requests
-from transformers import AutoTokenizer
-
-
-class RewardModelClient:
- """This class is used to process the input sequences for the reward
- model."""
-
- def __init__(
- self,
- path,
- max_length=16384,
- max_response_length=4096,
- response_cut_side="left",
- server_type="sglang",
- server_address="127.0.0.1:30000",
- ):
- """
- Args:
- path: Path to the reward model.
- max_length: Maximum length of the input sequence.
- max_response_length: Maximum length of the response sequence.
- response_cut_side: Side to cut the response sequence if it exceeds the maximum length.
- server_type: Type of the server, can be "sglang", "vllm", or "lmdeploy".
- server_address: Address of the reword model server.
- """
- self.rm_name = path.split("/")[-1]
- self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
- # for final reward token and one <|reward|> token and two '\n' tokens
- self.max_length = max_length - 4
- self.max_response_length = max_response_length
- self.response_cut_side = response_cut_side
- self.server_type = server_type
- self.server_address = server_address
-
- def _encode(self, prompt, reference, output, wrapper="sft") -> str:
- """Construct the input string for the reward model.
-
- Args:
- prompt: Prompt.
- reference: Reference trajectory.
- output: Candidate trajectory.
- wrapper: The wrapper type. Can be "sft" or "pretrain".
- Returns:
- The constructed input string for RM.
- """
- p = (
- "\n".join([e["content"] for e in prompt])
- if isinstance(prompt, list)
- else prompt
- )
- r1 = (
- "\n".join([e["content"] for e in reference])
- if isinstance(reference, list)
- else reference
- )
- r2 = (
- "\n".join([e["content"] for e in output])
- if isinstance(output, list)
- else output
- )
-
- p_ids = self.tokenizer.encode(p, add_special_tokens=True)
- r1_ids = self.tokenizer.encode(r1, add_special_tokens=True)
- r2_ids = self.tokenizer.encode(r2, add_special_tokens=True)
-
- if len(r1_ids) > self.max_response_length:
- print(
- f"Reference sequence length {len(r1_ids)} is "
- f"larger than max_response_length {self.max_response_length}",
- )
- if self.response_cut_side == "right":
- r1_ids = r1_ids[: self.max_response_length]
- else:
- r1_ids = r1_ids[-self.max_response_length :]
- if len(r2_ids) > self.max_response_length:
- print(
- f"Output sequence length {len(r2_ids)} is "
- f"larger than max_response_length {self.max_response_length}",
- )
- if self.response_cut_side == "right":
- r2_ids = r2_ids[: self.max_response_length]
- else:
- r2_ids = r2_ids[-self.max_response_length :]
-
- max_prompt_length = (self.max_length - len(r1_ids) - len(r2_ids)) // 2
-
- if len(p_ids) > max_prompt_length:
- print(
- f"Prompt sequence length {len(p_ids)} is "
- f"larger than max_prompt_length {max_prompt_length}",
- )
- p_ids = p_ids[-max_prompt_length:]
-
- p = self.tokenizer.decode(p_ids, skip_special_tokens=True)
- r1 = self.tokenizer.decode(r1_ids, skip_special_tokens=True)
- r2 = self.tokenizer.decode(r2_ids, skip_special_tokens=True)
-
- # Fit the template of RM
- _reference_cat = (
- p + r1 if wrapper == "pretrain" or len(r1) == "" else p + "\n" + r1
- )
- _output_cat = (
- p + r2 if wrapper == "pretrain" or len(r2) == "" else p + "\n" + r2
- )
-
- final_txt = _reference_cat + "<|reward|>" + _output_cat + "[UNUSED_TOKEN_130]"
-
- return final_txt
-
- def encode(self, data) -> Union[str, List[str]]:
- """Encode the input data into a format suitable for RM.
-
- Args:
- data: A dictionary or a list of dictionary containing the keys
- 'prompt', 'reference', 'output', and optionally 'wrapper'.
- Returns:
- The encoded input string for RM.
- """
- if isinstance(data, dict):
- return self._encode(**data)
- elif isinstance(data, list):
- return [
- self._encode(**item) if isinstance(item, dict) else item
- for item in data
- ]
- else:
- raise ValueError(
- "Input data must be a dictionary or a list of dictionaries."
- )
-
- def sglang_request_reward(
- self, data, retry_delay=0.2, max_retries=8
- ) -> List[float]:
- for i in range(max_retries):
- try:
- res = requests.post(
- f"http://{self.server_address}/classify",
- json={
- "model": self.rm_name,
- "text": data,
- },
- )
- rewards = [e["embedding"][0] for e in res.json()]
- return rewards
- except Exception as e:
- print(f"Error requesting reward: {e}")
- print(f"Raw response: {data}")
- time.sleep(retry_delay)
- continue
- print(f"Failed to request reward after {max_retries} retries")
- return None
-
- def vllm_request_reward(self, data, retry_delay=0.2, max_retries=8) -> List[float]:
- for i in range(max_retries):
- try:
- res = requests.post(
- f"http://{self.server_address}/pooling",
- json={
- "input": data,
- },
- )
- rewards = [e["data"][-1][0] for e in res.json()["data"]]
- return rewards
- except Exception as e:
- print(f"Error requesting reward: {e}")
- print(f"Raw response: {data}")
- time.sleep(retry_delay)
- continue
- print(f"Failed to request reward after {max_retries} retries")
- return None
-
- def lmdeploy_request_reward(
- self, data, retry_delay=0.2, max_retries=8
- ) -> List[float]:
- for i in range(max_retries):
- try:
- res = requests.post(
- f"http://{self.server_address}/pooling",
- json={
- "input": data,
- },
- )
- rewards = [e["data"] for e in res.json()["data"]]
- return rewards
- except Exception as e:
- print(f"Error requesting reward: {e}")
- print(f"Raw response: {data}")
- time.sleep(retry_delay)
- continue
- print(f"Failed to request reward after {max_retries} retries")
- return None
-
- def __call__(self, data) -> List[float]:
- """Call the input wrapper to construct the input string for RM.
-
- Args:
- data: A list of dictionaries containing the keys
- 'prompt', 'reference', 'output', and optionally 'wrapper'.
- retry_delay: Delay in seconds before retrying the request.
- max_retries: Maximum number of retries for the request.
- Returns:
- scores: The list of reward scores returned by the RM server.
- If the request fails, it returns None.
- """
- data = self.encode(data)
- if self.server_type == "sglang":
- scores = self.sglang_request_reward(data)
- elif self.server_type == "vllm":
- scores = self.vllm_request_reward(data)
- elif self.server_type == "lmdeploy":
- scores = self.lmdeploy_request_reward(data)
- else:
- raise ValueError(f"Unsupported server type: {self.server_type}")
-
- return scores
-
-
-if __name__ == "__main__":
- # Example usage
- ex1 = [
- {
- "prompt": "How many 'r's are in the word 'strawberry'?",
- "output": "There are three 'r's in the word 'strawberry'.",
- "reference": "3.",
- },
- {
- "prompt": "How many 'r's are in the word 'strawberry'?",
- "reference": "3.",
- "output": "There are two 'r's in the word 'strawberry'.",
- },
- ]
-
- ex2 = [
- {
- "prompt": [
- {
- "role": "user",
- "content": "How many 'r's are in the word 'strawberry'?",
- }
- ],
- "reference": [{"role": "assistant", "content": "3."}],
- "output": [
- {
- "role": "assistant",
- "content": "There are three 'r's in the word 'strawberry'.",
- }
- ],
- },
- {
- "prompt": [
- {
- "role": "user",
- "content": "How many 'r's are in the word 'strawberry'?",
- }
- ],
- "reference": [{"role": "assistant", "content": "3."}],
- "output": [
- {
- "role": "assistant",
- "content": "There are two 'r's in the word 'strawberry'.",
- }
- ],
- },
- ]
-
- # sglang
- client = RewardModelClient(
- "internlm/POLAR-7B", server_type="sglang", server_address="127.0.0.1:30000"
- )
-
- scores = client(ex1)
- print(scores)
-
- encoded_text = client.encode(ex2)
- scores = client.sglang_request_reward(encoded_text)
- print(scores)
-
- # vllm
- client = RewardModelClient(
- "internlm/POLAR-7B", server_type="vllm", server_address="127.0.0.1:30000"
- )
-
- scores = client(ex1)
- print(scores)
-
- encoded_text = client.encode(ex2)
- scores = client.vllm_request_reward(encoded_text)
- print(scores)
-
- # lmdeploy
- client = RewardModelClient(
- "internlm/POLAR-7B", server_type="lmdeploy", server_address="127.0.0.1:30000"
- )
-
- scores = client(ex1)
- print(scores)
-
- encoded_text = client.encode(ex2)
- scores = client.lmdeploy_request_reward(encoded_text)
- print(scores)
diff --git a/code/xtuner/utils/stop_criteria.py b/code/xtuner/utils/stop_criteria.py
deleted file mode 100644
index 954cc9d700af18f4951eab4fa881cc34d900f365..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/stop_criteria.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from transformers import StoppingCriteria
-
-
-class StopWordStoppingCriteria(StoppingCriteria):
- """StopWord stopping criteria."""
-
- def __init__(self, tokenizer, stop_word):
- self.tokenizer = tokenizer
- self.stop_word = stop_word
- self.length = len(self.stop_word)
-
- def __call__(self, input_ids, *args, **kwargs) -> bool:
- cur_text = self.tokenizer.decode(input_ids[0])
- cur_text = cur_text.replace('\r', '').replace('\n', '')
- return cur_text[-self.length:] == self.stop_word
diff --git a/code/xtuner/utils/templates.py b/code/xtuner/utils/templates.py
deleted file mode 100644
index 770775e293403cd9ea90d06047a5f14e72808fe0..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/templates.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.config import ConfigDict
-
-# - Turn 0: SYSTEM + INSTRUCTION, [output + SUFFIX], SEP
-# - Turn 1: INSTRUCTION, [output + SUFFIX], SEP
-# - Turn ...
-# Note: [] means having supervised loss during the fine-tuning
-PROMPT_TEMPLATE = ConfigDict(
- default=dict(
- SYSTEM='<|System|>:{system}\n',
- INSTRUCTION='<|User|>:{input}\n<|Bot|>:',
- SEP='\n'),
- zephyr=dict(
- SYSTEM='<|system|>\n{system}\n',
- INSTRUCTION='<|user|>\n{input}\n<|assistant|>\n',
- SEP='\n'),
- internlm_chat=dict(
- SYSTEM='<|System|>:{system}\n',
- INSTRUCTION='<|User|>:{input}\n<|Bot|>:',
- SUFFIX='',
- SUFFIX_AS_EOS=True,
- SEP='\n',
- STOP_WORDS=['']),
- internlm2_chat=dict(
- SYSTEM='<|im_start|>system\n{system}<|im_end|>\n',
- INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
- '<|im_start|>assistant\n'),
- SUFFIX='<|im_end|>',
- SUFFIX_AS_EOS=True,
- SEP='\n',
- STOP_WORDS=['<|im_end|>']),
- moss_sft=dict(
- SYSTEM='{system}\n',
- INSTRUCTION='<|Human|>: {input}\n',
- SEP='\n',
- STOP_WORDS=['', '']),
- llama2_chat=dict(
- SYSTEM=(
- '[INST] <>\n You are a helpful, respectful and honest '
- 'assistant. Always answer as helpfully as possible, while being '
- 'safe. Your answers should not include any harmful, unethical, '
- 'racist, sexist, toxic, dangerous, or illegal content. Please '
- 'ensure that your responses are socially unbiased and positive in '
- 'nature.\n{system}\n<>\n [/INST] '),
- INSTRUCTION='[INST] {input} [/INST]',
- SEP='\n'),
- code_llama_chat=dict(
- SYSTEM='{system}\n', INSTRUCTION='[INST] {input} [/INST]'),
- chatglm2=dict(
- SYSTEM='{system}\n',
- INSTRUCTION='[Round {round}]\n\n问:{input}\n\n答:',
- SEP='\n\n'),
- chatglm3=dict(
- SYSTEM='<|system|>\n{system}',
- INSTRUCTION='<|user|>\n{input}<|assistant|>\n',
- SEP='\n'),
- qwen_chat=dict(
- SYSTEM=('<|im_start|>system\n{system}<|im_end|>\n'),
- INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
- '<|im_start|>assistant\n'),
- SUFFIX='<|im_end|>',
- SUFFIX_AS_EOS=True,
- SEP='\n',
- STOP_WORDS=['<|im_end|>', '<|endoftext|>']),
- qwen3_no_think_chat = dict(
- SYSTEM = '<|im_start|>system\n{system}\n/no_think<|im_end|>\n',
- INSTRUCTION = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n',
- SUFFIX = '<|im_end|>',
- SUFFIX_AS_EOS = True,
- SEP = '\n',
- STOP_WORDS = ['<|im_end|>', '<|endoftext|>'],
- ),
- baichuan_chat=dict(
- SYSTEM='{system}\n',
- INSTRUCTION='{input}',
- SEP='\n'),
- baichuan2_chat=dict(
- SYSTEM='{system}\n',
- INSTRUCTION='{input}',
- SEP='\n'),
- wizardlm=dict(
- SYSTEM=('A chat between a curious user and an artificial '
- 'intelligence assistant. The assistant gives '
- 'helpful, detailed, and polite answers to the '
- 'user\'s questions. {system}\n '),
- INSTRUCTION=('USER: {input} ASSISTANT:'),
- SEP='\n'),
- wizardcoder=dict(
- SYSTEM=(
- 'Below is an instruction that describes a task. '
- 'Write a response that appropriately completes the request.\n\n'
- '{system}\n '),
- INSTRUCTION=('### Instruction:\n{input}\n\n### Response:'),
- SEP='\n\n'),
- vicuna=dict(
- SYSTEM=('A chat between a curious user and an artificial '
- 'intelligence assistant. The assistant gives '
- 'helpful, detailed, and polite answers to the '
- 'user\'s questions. {system}\n '),
- INSTRUCTION=('USER: {input} ASSISTANT:'),
- SEP='\n'),
- deepseek_coder=dict(
- SYSTEM=('You are an AI programming assistant, utilizing '
- 'the DeepSeek Coder model, developed by DeepSeek'
- 'Company, and you only answer questions related '
- 'to computer science. For politically sensitive '
- 'questions, security and privacy issues, and '
- 'other non-computer science questions, you will '
- 'refuse to answer. {system}\n'),
- INSTRUCTION=('### Instruction:\n{input}\n### Response:\n'),
- SEP='\n'),
- # TODO: deprecation, v0.2.0
- deepseekcoder=dict(
- SYSTEM=('You are an AI programming assistant, utilizing '
- 'the DeepSeek Coder model, developed by DeepSeek'
- 'Company, and you only answer questions related '
- 'to computer science. For politically sensitive '
- 'questions, security and privacy issues, and '
- 'other non-computer science questions, you will '
- 'refuse to answer. {system}\n'),
- INSTRUCTION=('### Instruction:\n{input}\n### Response:\n'),
- SEP='\n'),
- deepseek_moe=dict(
- SYSTEM=('[INST] {system} [/INST]\n'),
- INSTRUCTION=('[INST] {input} [/INST]'),
- SEP='\n'),
- deepseek_v2=dict(
- SYSTEM='{system}\n\n',
- INSTRUCTION='User: {input}\n\nAssistant: ',
- SUFFIX='<|end▁of▁sentence|>',
- SUFFIX_AS_EOS=True,
- STOP_WORDS=['<|end▁of▁sentence|>']),
- mistral=dict(
- SYSTEM=('[INST] {system} [/INST]\n'),
- INSTRUCTION=('[INST] {input} [/INST]'),
- SEP='\n'),
- mixtral=dict(
- SYSTEM=('[INST] {system} [/INST]\n'),
- INSTRUCTION=('[INST] {input} [/INST]'),
- SEP='\n'),
- minicpm=dict(INSTRUCTION=('<用户> {input} '), SEP='\n'),
- gemma=dict(
- # `system` field is extended by xtuner
- SYSTEM=('system\n{system}\n'),
- INSTRUCTION=('user\n{input}\n'
- 'model\n'),
- SUFFIX='',
- SUFFIX_AS_EOS=False,
- SEP='\n',
- STOP_WORDS=['']),
- cohere_chat=dict(
- SYSTEM=('<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}'
- '<|END_OF_TURN_TOKEN|>'),
- INSTRUCTION=(
- '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{input}<|END_OF_TURN_TOKEN|>'
- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'),
- SUFFIX='<|END_OF_TURN_TOKEN|>',
- SUFFIX_AS_EOS=True,
- STOP_WORDS=['<|END_OF_TURN_TOKEN|>']),
- llama3_chat=dict(
- SYSTEM=('<|start_header_id|>system<|end_header_id|>\n\n'
- '{system}<|eot_id|>'),
- INSTRUCTION=(
- '<|start_header_id|>user<|end_header_id|>\n\n{input}<|eot_id|>'
- '<|start_header_id|>assistant<|end_header_id|>\n\n'),
- SUFFIX='<|eot_id|>',
- SUFFIX_AS_EOS=True,
- STOP_WORDS=['<|eot_id|>']),
- phi3_chat=dict(
- SYSTEM='<|system|>\n{system}<|end|>\n',
- INSTRUCTION='<|user|>\n{input}<|end|>\n<|assistant|>\n',
- SUFFIX='<|end|>',
- SUFFIX_AS_EOS=True,
- SEP='\n',
- STOP_WORDS=['<|end|>']),
-)
-
-SYSTEM_TEMPLATE = ConfigDict(
- moss_sft=('You are an AI assistant whose name is {bot_name}.\n'
- 'Capabilities and tools that {bot_name} can possess.\n'
- '- Inner thoughts: enabled.\n'
- '- Web search: enabled. API: Search(query)\n'
- '- Calculator: enabled. API: Calculate(expression)\n'
- '- Equation solver: enabled. API: Solve(equation)\n'
- '- Text-to-image: disabled.\n'
- '- Image edition: disabled.\n'
- '- Text-to-speech: disabled.\n'),
- alpaca=('Below is an instruction that describes a task. '
- 'Write a response that appropriately completes the request.\n'),
- arxiv_gentile=('If you are an expert in writing papers, please generate '
- "a good paper title for this paper based on other authors' "
- 'descriptions of their abstracts.\n'),
- colorist=('You are a professional color designer. Please provide the '
- 'corresponding colors based on the description of Human.\n'),
- coder=('You are a professional programer. Please provide the '
- 'corresponding code based on the description of Human.\n'),
- lawyer='你现在是一名专业的中国律师,请根据用户的问题给出准确、有理有据的回复。\n',
- medical='如果你是一名医生,请根据患者的描述回答医学问题。\n',
- sql=('If you are an expert in SQL, please generate a good SQL Query '
- 'for Question based on the CREATE TABLE statement.\n'),
-)
diff --git a/code/xtuner/utils/zero_to_any_dtype.py b/code/xtuner/utils/zero_to_any_dtype.py
deleted file mode 100644
index f26a8372a4daadf4cdc09ec40c4be06b83e7d9e4..0000000000000000000000000000000000000000
--- a/code/xtuner/utils/zero_to_any_dtype.py
+++ /dev/null
@@ -1,696 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-# This script extracts consolidated weights from a zero 1, 2 and 3 DeepSpeed
-# checkpoints. It gets copied into the top level checkpoint dir, so the user
-# can easily do the conversion at any point in the future. Once extracted, the
-# weights don't require DeepSpeed and can be used in any application.
-#
-# example: python zero_to_any_dtype.py . pytorch_model.bin
-
-import argparse
-import glob
-import math
-import os
-import re
-from collections import OrderedDict
-from dataclasses import dataclass
-
-import torch
-# yapf: disable
-from deepspeed.checkpoint.constants import (BUFFER_NAMES, DS_VERSION,
- FP32_FLAT_GROUPS,
- FROZEN_PARAM_FRAGMENTS,
- FROZEN_PARAM_SHAPES,
- OPTIMIZER_STATE_DICT, PARAM_SHAPES,
- PARTITION_COUNT,
- SINGLE_PARTITION_OF_FP32_GROUPS,
- ZERO_STAGE)
-# while this script doesn't use deepspeed to recover data, since the
-# checkpoints are pickled with DeepSpeed data structures it has to be
-# available in the current python environment.
-from deepspeed.utils import logger
-from tqdm import tqdm
-
-# yapf: enable
-
-
-@dataclass
-class zero_model_state:
- buffers: dict()
- param_shapes: dict()
- shared_params: list
- ds_version: int
- frozen_param_shapes: dict()
- frozen_param_fragments: dict()
-
-
-debug = 0
-
-# load to cpu
-device = torch.device('cpu')
-
-DEFAULT_DTYPE = torch.float16
-
-
-def atoi(text):
- return int(text) if text.isdigit() else text
-
-
-def natural_keys(text):
- """alist.sort(key=natural_keys) sorts in human order
- http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's
- implementation in the comments)"""
- return [atoi(c) for c in re.split(r'(\d+)', text)]
-
-
-def get_model_state_file(checkpoint_dir, zero_stage):
- if not os.path.isdir(checkpoint_dir):
- raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
-
- # there should be only one file
- if zero_stage <= 2:
- file = os.path.join(checkpoint_dir, 'mp_rank_00_model_states.pt')
- elif zero_stage == 3:
- file = os.path.join(checkpoint_dir,
- 'zero_pp_rank_0_mp_rank_00_model_states.pt')
-
- if not os.path.exists(file):
- raise FileNotFoundError(f"can't find model states file at '{file}'")
-
- return file
-
-
-def get_checkpoint_files(checkpoint_dir, glob_pattern):
- # XXX: need to test that this simple glob rule works for multi-node
- # setup too
- ckpt_files = sorted(
- glob.glob(os.path.join(checkpoint_dir, glob_pattern)),
- key=natural_keys)
-
- if len(ckpt_files) == 0:
- raise FileNotFoundError(
- f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
-
- return ckpt_files
-
-
-def get_optim_files(checkpoint_dir):
- return get_checkpoint_files(checkpoint_dir, '*_optim_states.pt')
-
-
-def get_model_state_files(checkpoint_dir):
- return get_checkpoint_files(checkpoint_dir, '*_model_states.pt')
-
-
-def parse_model_states(files, dtype=DEFAULT_DTYPE):
- zero_model_states = []
- for file in files:
- state_dict = torch.load(file, map_location=device, weights_only=False)
-
- if BUFFER_NAMES not in state_dict:
- raise ValueError(f'{file} is not a model state checkpoint')
- buffer_names = state_dict[BUFFER_NAMES]
- if debug:
- print('Found buffers:', buffer_names)
-
- buffers = {
- k: v.to(dtype)
- for k, v in state_dict['module'].items() if k in buffer_names
- }
- param_shapes = state_dict[PARAM_SHAPES]
-
- # collect parameters that are included in param_shapes
- param_names = []
- for s in param_shapes:
- for name in s.keys():
- param_names.append(name)
-
- # update with frozen parameters
- frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
- if frozen_param_shapes is not None:
- if debug:
- print(f'Found frozen_param_shapes: {frozen_param_shapes}')
- param_names += list(frozen_param_shapes.keys())
-
- # handle shared params
- shared_params = [[k, v]
- for k, v in state_dict['shared_params'].items()]
-
- ds_version = state_dict.get(DS_VERSION, None)
-
- frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
-
- z_model_state = zero_model_state(
- buffers=buffers,
- param_shapes=param_shapes,
- shared_params=shared_params,
- ds_version=ds_version,
- frozen_param_shapes=frozen_param_shapes,
- frozen_param_fragments=frozen_param_fragments)
- zero_model_states.append(z_model_state)
-
- return zero_model_states
-
-
-@torch.no_grad()
-def parse_optim_states(files, ds_checkpoint_dir, dtype=DEFAULT_DTYPE):
-
- zero_stage = None
- world_size = None
- total_files = len(files)
- flat_groups = []
- for f in tqdm(files, desc='Load Checkpoints'):
- state_dict = torch.load(f, map_location=device, weights_only=False)
- if ZERO_STAGE not in state_dict[OPTIMIZER_STATE_DICT]:
- raise ValueError(f'{f} is not a zero checkpoint')
-
- zero_stage = state_dict[OPTIMIZER_STATE_DICT][ZERO_STAGE]
- world_size = state_dict[OPTIMIZER_STATE_DICT][PARTITION_COUNT]
-
- # the groups are named differently in each stage
- if zero_stage <= 2:
- fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
- elif zero_stage == 3:
- fp32_groups_key = FP32_FLAT_GROUPS
- else:
- raise ValueError(f'unknown zero stage {zero_stage}')
-
- # immediately discard the potentially huge 2 optimizer states as we
- # only care for fp32 master weights and also handle the case where it
- # was already removed by another helper script
- state_dict['optimizer_state_dict'].pop('optimizer_state_dict', None)
- fp32_groups = state_dict['optimizer_state_dict'].pop(fp32_groups_key)
- if zero_stage <= 2:
- flat_groups.append([param.to(dtype) for param in fp32_groups])
- elif zero_stage == 3:
- # if there is more than one param group, there will be multiple
- # flattened tensors - one flattened tensor per group - for
- # simplicity merge them into a single tensor
-
- # XXX: could make the script more memory efficient for when there
- # are multiple groups - it will require matching the sub-lists of
- # param_shapes for each param group flattened tensor
- flat_groups.append(torch.cat(fp32_groups, 0).to(dtype))
-
- # For ZeRO-2 each param group can have different partition_count as data
- # parallelism for expert parameters can be different from data parallelism
- # for non-expert parameters. So we can just use the max of the
- # partition_count to get the dp world_size.
- if type(world_size) is list:
- world_size = max(world_size)
-
- if world_size != total_files:
- raise ValueError(
- f"Expected {world_size} of '*_optim_states.pt' under "
- f"'{ds_checkpoint_dir}' but found {total_files} files. "
- 'Possibly due to an overwrite of an old checkpoint, '
- "or a checkpoint didn't get saved by one or more processes.")
-
- return zero_stage, world_size, flat_groups
-
-
-def _get_state_dict_from_zero_checkpoint(ds_checkpoint_dir,
- exclude_frozen_parameters,
- dtype=DEFAULT_DTYPE):
- """Returns state_dict reconstructed from ds checkpoint.
-
- Args:
- - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder
- (where the optimizer files are)
- """
- print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
-
- optim_files = get_optim_files(ds_checkpoint_dir)
- zero_stage, world_size, flat_groups = parse_optim_states(
- optim_files, ds_checkpoint_dir, dtype)
- print(f'Detected checkpoint of type zero stage {zero_stage}, '
- f'world_size: {world_size}')
-
- model_files = get_model_state_files(ds_checkpoint_dir)
-
- zero_model_states = parse_model_states(model_files)
- print(f'Parsing checkpoint created by deepspeed=='
- f'{zero_model_states[0].ds_version}')
-
- if zero_stage <= 2:
- return _get_state_dict_from_zero2_checkpoint(
- world_size, flat_groups, zero_model_states,
- exclude_frozen_parameters)
- elif zero_stage == 3:
- return _get_state_dict_from_zero3_checkpoint(
- world_size, flat_groups, zero_model_states,
- exclude_frozen_parameters)
-
-
-def _zero2_merge_frozen_params(state_dict, zero_model_states):
- if zero_model_states[0].frozen_param_shapes is None or len(
- zero_model_states[0].frozen_param_shapes) == 0:
- return
-
- frozen_param_shapes = zero_model_states[0].frozen_param_shapes
- frozen_param_fragments = zero_model_states[0].frozen_param_fragments
-
- if debug:
- num_elem = sum(s.numel() for s in frozen_param_shapes.values())
- print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
-
- wanted_params = len(frozen_param_shapes)
- wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
- avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
- print(f'Frozen params: Have {avail_numel} numels to process.')
- print(f'Frozen params: Need {wanted_numel} numels in '
- f'{wanted_params} params')
-
- total_params = 0
- total_numel = 0
- for name, shape in frozen_param_shapes.items():
- total_params += 1
- unpartitioned_numel = shape.numel()
- total_numel += unpartitioned_numel
-
- state_dict[name] = frozen_param_fragments[name]
-
- if debug:
- print(f'{name} full shape: {shape} unpartitioned numel '
- f'{unpartitioned_numel} ')
-
- print(f'Reconstructed Frozen state dict with {total_params} params '
- f'{total_numel} elements')
-
-
-def _has_callable(obj, fn):
- attr = getattr(obj, fn, None)
- return callable(attr)
-
-
-def _zero2_merge_trainable_params(state_dict, world_size, flat_groups,
- zero_model_states):
- param_shapes = zero_model_states[0].param_shapes
-
- # Reconstruction protocol:
- #
- # XXX: document this
-
- if debug:
- for i in range(world_size):
- for j in range(len(flat_groups[0])):
- print(f'flat_groups[{i}][{j}].shape={flat_groups[i][j].shape}')
-
- # XXX: memory usage doubles here (zero2)
- num_param_groups = len(flat_groups[0])
- merged_single_partition_of_groups = []
- for i in range(num_param_groups):
- merged_partitions = [sd[i] for sd in flat_groups]
- full_single_vector = torch.cat(merged_partitions, 0)
- merged_single_partition_of_groups.append(full_single_vector)
- avail_numel = sum([
- full_single_vector.numel()
- for full_single_vector in merged_single_partition_of_groups
- ])
-
- if debug:
- wanted_params = sum([len(shapes) for shapes in param_shapes])
- wanted_numel = sum([
- sum(shape.numel() for shape in shapes.values())
- for shapes in param_shapes
- ])
- # not asserting if there is a mismatch due to possible padding
- print(f'Have {avail_numel} numels to process.')
- print(f'Need {wanted_numel} numels in {wanted_params} params.')
-
- # params
- # XXX: for huge models that can't fit into the host's RAM we will have to
- # recode this to support out-of-core computing solution
- total_numel = 0
- total_params = 0
- for shapes, full_single_vector in zip(param_shapes,
- merged_single_partition_of_groups):
- offset = 0
- avail_numel = full_single_vector.numel()
- for name, shape in shapes.items():
-
- unpartitioned_numel = shape.numel() if _has_callable(
- shape, 'numel') else math.prod(shape)
- total_numel += unpartitioned_numel
- total_params += 1
-
- if debug:
- print(f'{name} full shape: {shape} unpartitioned numel '
- f'{unpartitioned_numel} ')
- state_dict[name] = full_single_vector.narrow(
- 0, offset, unpartitioned_numel).view(shape)
- offset += unpartitioned_numel
-
- # Z2 started to align to 2*world_size to improve nccl performance.
- # Therefore both offset and avail_numel can differ by anywhere between
- # 0..2*world_size. Due to two unrelated complex paddings performed in
- # the code it's almost impossible to predict the exact numbers w/o the
- # live optimizer object, so we are checking that the numbers are
- # within the right range
- align_to = 2 * world_size
-
- def zero2_align(x):
- return align_to * math.ceil(x / align_to)
-
- if debug:
- print(f'original offset={offset}, avail_numel={avail_numel}')
-
- offset = zero2_align(offset)
- avail_numel = zero2_align(avail_numel)
-
- if debug:
- print(f'aligned offset={offset}, avail_numel={avail_numel}')
-
- # Sanity check
- if offset != avail_numel:
- raise ValueError(f'consumed {offset} numels out of {avail_numel} '
- '- something is wrong')
-
- print(f'Reconstructed state dict with {total_params} params '
- f'{total_numel} elements')
-
-
-def _get_state_dict_from_zero2_checkpoint(world_size, flat_groups,
- zero_model_states,
- exclude_frozen_parameters):
- state_dict = OrderedDict()
-
- # buffers
- buffers = zero_model_states[0].buffers
- state_dict.update(buffers)
- if debug:
- print(f'added {len(buffers)} buffers')
-
- if not exclude_frozen_parameters:
- _zero2_merge_frozen_params(state_dict, zero_model_states)
-
- _zero2_merge_trainable_params(state_dict, world_size, flat_groups,
- zero_model_states)
-
- # recover shared parameters
- for pair in zero_model_states[0].shared_params:
- if pair[1] in state_dict:
- state_dict[pair[0]] = state_dict[pair[1]]
-
- return state_dict
-
-
-def zero3_partitioned_param_info(unpartitioned_numel, world_size):
- remainder = unpartitioned_numel % world_size
- padding_numel = (world_size - remainder) if remainder else 0
- partitioned_numel = math.ceil(unpartitioned_numel / world_size)
- return partitioned_numel, padding_numel
-
-
-def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
- if zero_model_states[0].frozen_param_shapes is None or len(
- zero_model_states[0].frozen_param_shapes) == 0:
- return
-
- if debug:
- for i in range(world_size):
- num_elem = sum(
- s.numel()
- for s in zero_model_states[i].frozen_param_fragments.values())
- print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
-
- frozen_param_shapes = zero_model_states[0].frozen_param_shapes
- wanted_params = len(frozen_param_shapes)
- wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
- avail_numel = sum([
- p.numel()
- for p in zero_model_states[0].frozen_param_fragments.values()
- ]) * world_size
- print(f'Frozen params: Have {avail_numel} numels to process.')
- print(f'Frozen params: Need {wanted_numel} numels in '
- f'{wanted_params} params')
-
- total_params = 0
- total_numel = 0
- for name, shape in zero_model_states[0].frozen_param_shapes.items():
- total_params += 1
- unpartitioned_numel = shape.numel()
- total_numel += unpartitioned_numel
-
- param_frags = tuple(model_state.frozen_param_fragments[name]
- for model_state in zero_model_states)
- state_dict[name] = torch.cat(param_frags, 0).narrow(
- 0, 0, unpartitioned_numel).view(shape) # noqa: E501
-
- _partitioned = zero3_partitioned_param_info(unpartitioned_numel,
- world_size)
- partitioned_numel, partitioned_padding_numel = _partitioned
- if debug:
- print(f'Frozen params: {total_params} {name} full shape: {shape} '
- f'partition0 numel={partitioned_numel} '
- f'partitioned_padding_numel={partitioned_padding_numel}')
-
- print(f'Reconstructed Frozen state dict with {total_params} params '
- f'{total_numel} elements')
-
-
-def _zero3_merge_trainable_params(state_dict, world_size, flat_groups,
- zero_model_states):
- param_shapes = zero_model_states[0].param_shapes
- avail_numel = flat_groups[0].numel() * world_size
- # Reconstruction protocol: For zero3 we need to zip the partitions
- # together at boundary of each param, re-consolidating each param, while
- # dealing with padding if any
-
- # merge list of dicts, preserving order
- param_shapes = {k: v for d in param_shapes for k, v in d.items()}
-
- if debug:
- for i in range(world_size):
- print(f'flat_groups[{i}].shape={flat_groups[i].shape}')
-
- wanted_params = len(param_shapes)
- wanted_numel = sum(shape.numel() for shape in param_shapes.values())
- # not asserting if there is a mismatch due to possible padding
- avail_numel = flat_groups[0].numel() * world_size
- print(f'Trainable params: Have {avail_numel} numels to process.')
- print(f'Trainable params: Need {wanted_numel} numels in '
- f'{wanted_params} params.')
-
- offset = 0
- total_numel = 0
- total_params = 0
- partitioned_sizes = []
- for name, shape in param_shapes.items():
-
- unpartitioned_numel = shape.numel()
- total_numel += unpartitioned_numel
- total_params += 1
-
- _info = zero3_partitioned_param_info(unpartitioned_numel, world_size)
-
- partitioned_numel, partitioned_padding_numel = _info
- partitioned_sizes.append(partitioned_numel)
- if debug:
- print(
- f'Trainable params: {total_params} {name} full shape: {shape} '
- f'partition0 numel={partitioned_numel} '
- f'partitioned_padding_numel={partitioned_padding_numel}')
-
- offset += partitioned_numel
-
- offset *= world_size
-
- # Sanity check
- if offset != avail_numel:
- raise ValueError(f'consumed {offset} numels out of {avail_numel} '
- '- something is wrong')
-
- mat_chunks = []
- for rank in range(world_size):
- rank_chunks = flat_groups.pop(0).split(partitioned_sizes)
- rank_chunks = [tensor.clone() for tensor in rank_chunks]
- mat_chunks.append(rank_chunks)
-
- for name, shape in tqdm(
- param_shapes.items(), desc='Gather Sharded Weights'):
-
- pad_flat_param_chunks = []
- for rank in range(world_size):
- pad_flat_param_chunks.append(mat_chunks[rank].pop(0))
-
- pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
-
- # Because pad_flat_param_chunks is a list, it is necessary to manually
- # release the tensors in the list; Python will not automatically do so.
- for rank in range(world_size):
- pad_flat_param_chunks.pop()
-
- param = pad_flat_param[:shape.numel()].view(shape)
- state_dict[name] = param
-
- print(f'Reconstructed Trainable state dict with {total_params} params '
- f'{total_numel} elements')
-
-
-def _get_state_dict_from_zero3_checkpoint(world_size, flat_groups,
- zero_model_states,
- exclude_frozen_parameters):
- state_dict = OrderedDict()
-
- # buffers
- buffers = zero_model_states[0].buffers
- state_dict.update(buffers)
- if debug:
- print(f'added {len(buffers)} buffers')
-
- if not exclude_frozen_parameters:
- _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
-
- _zero3_merge_trainable_params(state_dict, world_size, flat_groups,
- zero_model_states)
-
- # recover shared parameters
- for pair in zero_model_states[0].shared_params:
- if pair[1] in state_dict:
- state_dict[pair[0]] = state_dict[pair[1]]
-
- return state_dict
-
-
-def get_state_dict_from_zero_checkpoint(checkpoint_dir,
- tag=None,
- exclude_frozen_parameters=False,
- dtype=DEFAULT_DTYPE):
- # flake8: noqa
- """Convert ZeRO 2 or 3 checkpoint into a single consolidated state_dict
- that can be loaded with ``load_state_dict()`` and used for training without
- DeepSpeed or shared with others, for example via a model hub.
-
- Args:
- - ``checkpoint_dir``: path to the desired checkpoint folder
- - ``tag``: checkpoint tag used as a unique identifier for checkpoint.
- If not provided will attempt to load tag in 'latest' file.
- e.g., ``global_step14``
- - ``exclude_frozen_parameters``: exclude frozen parameters
-
- Returns:
- - pytorch ``state_dict``
-
- Note: this approach may not work if your application doesn't have
- sufficient free CPU memory and you may need to use the offline approach
- using the ``zero_to_any_dtype.py`` script that is saved with the
- checkpoint.
-
- A typical usage might be ::
-
- from xtuner.utils.zero_to_any_dtype import get_state_dict_from_zero_checkpoint
- # do the training and checkpoint saving
- state_dict = get_state_dict_from_zero_checkpoint(checkpoint_dir, dtype=torch.float16) # already on cpu
- model = model.cpu() # move to cpu
- model.load_state_dict(state_dict)
- # submit to model hub or save the model to share with others
-
- In this example the ``model`` will no longer be usable in the deepspeed
- context of the same application. i.e. you will need to re-initialize the
- deepspeed engine, since ``model.load_state_dict(state_dict)`` will remove
- all the deepspeed magic from it.
-
- If you want it all done for you, use
- ``load_state_dict_from_zero_checkpoint`` instead.
- """
- # flake8: noqa
- if tag is None:
- latest_path = os.path.join(checkpoint_dir, 'latest')
- if os.path.isfile(latest_path):
- with open(latest_path) as fd:
- tag = fd.read().strip()
- else:
- raise ValueError(f"Unable to find 'latest' file at {latest_path}")
-
- ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
-
- if not os.path.isdir(ds_checkpoint_dir):
- raise FileNotFoundError(
- f"Directory '{ds_checkpoint_dir}' doesn't exist")
-
- return _get_state_dict_from_zero_checkpoint(ds_checkpoint_dir,
- exclude_frozen_parameters,
- dtype)
-
-
-def convert_zero_checkpoint_to_state_dict(checkpoint_dir,
- output_file,
- tag=None,
- exclude_frozen_parameters=False,
- dtype=DEFAULT_DTYPE):
- """Convert ZeRO 2 or 3 checkpoint into a single consolidated ``state_dict``
- file that can be loaded with ``torch.load(file)`` + ``load_state_dict()``
- and used for training without DeepSpeed.
-
- Args:
- - ``checkpoint_dir``: path to the desired checkpoint folder.
- (one that contains the tag-folder, like ``global_step14``)
- - ``output_file``: path to the pytorch state_dict output file
- (e.g. path/pytorch_model.bin)
- - ``tag``: checkpoint tag used as a unique identifier for checkpoint.
- If not provided will attempt to load tag in the file named
- ``latest`` in the checkpoint folder, e.g., ``global_step14``
- - ``exclude_frozen_parameters``: exclude frozen parameters
- """
-
- state_dict = get_state_dict_from_zero_checkpoint(
- checkpoint_dir, tag, exclude_frozen_parameters, dtype)
- print(f'Saving {dtype} state dict to {output_file}')
- torch.save(state_dict, output_file)
-
-
-def load_state_dict_from_zero_checkpoint(model,
- checkpoint_dir,
- tag=None,
- dtype=DEFAULT_DTYPE):
-
- # flake8: noqa
- """
- 1. Put the provided model to cpu
- 2. Convert ZeRO 2 or 3 checkpoint into a single consolidated ``state_dict``
- 3. Load it into the provided model
-
- Args:
- - ``model``: the model object to update
- - ``checkpoint_dir``: path to the desired checkpoint folder. (one that
- contains the tag-folder, like ``global_step14``)
- - ``tag``: checkpoint tag used as a unique identifier for checkpoint.
- If not provided will attempt to load tag in the file named
- ``latest`` in the checkpoint folder, e.g., ``global_step14``
-
- Returns:
- - ``model`: modified model
-
- Make sure you have plenty of CPU memory available before you call this
- function. If you don't have enough use the ``zero_to_any_dtype.py``
- utility to do the conversion. You will find it conveniently placed for you
- in the checkpoint folder.
-
- A typical usage might be ::
-
- from xtuner.utils.zero_to_any_dtype import load_state_dict_from_zero_checkpoint
- model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir, dtype=torch.float16)
- # submit to model hub or save the model to share with others
-
- Note, that once this was run, the ``model`` will no longer be usable in
- the deepspeed context of the same application. i.e. you will need to
- re-initialize the deepspeed engine, since
- ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic
- from it.
- """
- # flake8: noqa
- logger.info(f'Extracting {dtype} weights')
- state_dict = get_state_dict_from_zero_checkpoint(
- checkpoint_dir, tag, dtype=dtype)
-
- logger.info(f'Overwriting model with {dtype} weights')
- model = model.cpu()
- model.load_state_dict(state_dict, strict=False)
-
- return model
diff --git a/code/xtuner/version.py b/code/xtuner/version.py
deleted file mode 100644
index e4669c1880af551fc52eae2b826adfdd60e6a6d0..0000000000000000000000000000000000000000
--- a/code/xtuner/version.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-__version__ = '0.1.23'
-short_version = __version__
-
-
-def parse_version_info(version_str):
- """Parse a version string into a tuple.
-
- Args:
- version_str (str): The version string.
- Returns:
- tuple[int or str]: The version info, e.g., "1.3.0" is parsed into
- (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
- """
- version_info = []
- for x in version_str.split('.'):
- if x.isdigit():
- version_info.append(int(x))
- elif x.find('rc') != -1:
- patch_version = x.split('rc')
- version_info.append(int(patch_version[0]))
- version_info.append(f'rc{patch_version[1]}')
- return tuple(version_info)
-
-
-version_info = parse_version_info(__version__)