"""Shared utilities used across core and adapters.
Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding,
parameter grouping, model copying, etc.). Keep this file dependency-light.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import copy
import random
import numpy as np
import torch
import torch.nn as nn
# -----------------------------------------------------------------------------
# Device / dtype helpers
# -----------------------------------------------------------------------------
def as_like(x: torch.Tensor, val) -> torch.Tensor:
    """Materialize ``val`` as a tensor matching ``x``'s device and dtype."""
    target = {"device": x.device, "dtype": x.dtype}
    return torch.as_tensor(val, **target)
def first_param(module: nn.Module) -> torch.Tensor:
    """Return the first parameter of ``module`` (recursive), or a 0-d zero
    tensor when the module has no parameters at all."""
    return next(iter(module.parameters(recurse=True)), torch.tensor(0.0))
def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Move/cast ``x`` so it lives on ``ref``'s device with ``ref``'s dtype."""
    moved = x.to(ref.device)
    return moved.to(ref.dtype)
# -----------------------------------------------------------------------------
# Seeding & determinism
# -----------------------------------------------------------------------------
def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python's, NumPy's and PyTorch's RNGs with the same value.

    When ``deterministic`` is True, also pin cuDNN to its deterministic
    kernels and disable its autotuner (benchmark mode).
    """
    # CUDA seeding is a no-op on CPU-only builds, so this is always safe.
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
# -----------------------------------------------------------------------------
# Model parameter helpers
# -----------------------------------------------------------------------------
def freeze(module: nn.Module) -> None:
    """Disable gradient tracking for every parameter of ``module`` (in place)."""
    for param in module.parameters():
        param.requires_grad = False
def unfreeze(module: nn.Module) -> None:
    """Re-enable gradient tracking for every parameter of ``module`` (in place)."""
    for param in module.parameters():
        param.requires_grad = True
def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
    """Total number of elements across parameters; optionally only those
    with ``requires_grad`` set."""
    params = module.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
# -----------------------------------------------------------------------------
# Shape/signature helpers
# -----------------------------------------------------------------------------
def input_spec_vision(sample) -> Tuple[int, int, int]:
    """Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W)."""
    if isinstance(sample, torch.Tensor):
        dims = sample.shape
    elif isinstance(sample, (tuple, list)) and len(sample) == 4:
        dims = sample
    else:
        raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
    # Unpacking also rejects tensors that are not exactly 4-D.
    batch, _channels, height, width = dims
    return int(batch), int(height), int(width)
# -----------------------------------------------------------------------------
# Rounding / multiples
# -----------------------------------------------------------------------------
def round_down_multiple(n: int, m: int) -> int:
    """Round ``n`` down to a multiple of ``m``, never below ``m`` itself.

    With ``m`` None or <= 1 no snapping happens and the result is max(1, n).
    Note the floor: for n < m the result is m (clamped up, by design).
    """
    value = int(n)
    if m is None or m <= 1:
        return max(1, value)
    snapped = value - value % m  # == (value // m) * m
    return snapped if snapped >= m else m
def clamp_int(v: int, lo: int, hi: int) -> int:
    """Clamp ``v`` into [lo, hi]; ``lo`` wins if the bounds are inverted
    (max-of-min order, matching the original semantics)."""
    capped = hi if int(v) > hi else int(v)
    return lo if capped < lo else capped
# -----------------------------------------------------------------------------
# Slicing helpers
# -----------------------------------------------------------------------------
@torch.no_grad()
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
    """Build a new ``nn.Linear`` keeping only the requested feature indices.

    Args:
        mat: source layer; read-only, its tensors are copied.
        keep_in: input-feature (column) indices to keep; None keeps all.
        keep_out: output-feature (row) indices to keep; None keeps all.

    Returns:
        A fresh ``nn.Linear`` on the same device *and dtype* as ``mat``,
        with the selected rows/columns copied in (bias preserved iff present).
    """
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None
    if keep_out is not None:
        # index_select requires long indices; pin the dtype explicitly.
        idx_out = torch.as_tensor(keep_out, device=W.device, dtype=torch.long)
        W = W.index_select(0, idx_out)
        if b is not None:
            b = b.index_select(0, idx_out)
    if keep_in is not None:
        idx_in = torch.as_tensor(keep_in, device=W.device, dtype=torch.long)
        W = W.index_select(1, idx_in)
    out_f, in_f = W.shape
    # Bug fix: the original only moved the new layer to W's device; a
    # non-default dtype (float64/float16 source) was silently dropped and
    # copy_() cast the weights back to float32. Match device AND dtype.
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new
# -----------------------------------------------------------------------------
# Copying & detaching models
# -----------------------------------------------------------------------------
def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
    """Deep-copy ``module``, move the copy to CPU and switch it to eval mode.

    The original module is left untouched (device, mode and weights).
    """
    clone = copy.deepcopy(module)
    clone = clone.cpu()
    return clone.eval()
# -----------------------------------------------------------------------------
# Gradient utilities
# -----------------------------------------------------------------------------
def zero_if_any(params: Iterable[torch.Tensor]) -> None:
    """Drop accumulated gradients by setting ``.grad`` to None.

    Assigning None when ``.grad`` is already None is a no-op, so no guard
    is needed — the observable effect matches the guarded version.
    """
    for tensor in params:
        tensor.grad = None
def any_grad(params: Iterable[torch.Tensor]) -> bool:
    """True iff at least one tensor in ``params`` carries a non-None ``.grad``."""
    return any(p.grad is not None for p in params)
# -----------------------------------------------------------------------------
# For fine-tuning
# -----------------------------------------------------------------------------
def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
    """Rebuild every parameter as a fresh ``nn.Parameter`` (detach + clone).

    This drops any 'inference tensor' tag and re-enables autograd on the
    copies. Returns ``module`` itself (mutated in place) for chaining.
    """
    for sub in module.modules():
        # Snapshot the items: setattr below re-registers into _parameters.
        for pname, old in list(sub._parameters.items()):
            if old is None:
                continue
            replacement = nn.Parameter(old.detach().clone(), requires_grad=requires_grad)
            # nn.Module.__setattr__ routes Parameters into _parameters.
            setattr(sub, pname, replacement)
    return module
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
@dataclass
class ExportRounding:
    """Rounding knobs applied post-export.

    NOTE(review): field semantics below are inferred from the names only —
    confirm against whatever exporter consumes this config.
    """
    head_floor_post: int = 1  # presumably: minimum head count after export
    head_multiple_post: int = 1  # presumably: snap head count to this multiple
    ffn_min_keep_ratio_post: float = 0.0  # presumably: lower bound on kept FFN ratio
    ffn_snap_groups_post: int = 1  # presumably: FFN width snapped to this group count
def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
    """Hashable signature of a ViT config + input shape.

    ``cfg`` is duck-typed: missing attributes fall back to ViT-Base defaults
    (12 heads, 768 hidden, 3072 intermediate, patch 16).
    """
    # Unpacking validates that sample_shape has exactly four entries.
    _batch, _chans, _height, _width = sample_shape
    patch = getattr(cfg, "patch_size", 16)
    patch_sig = tuple(patch) if isinstance(patch, (tuple, list)) else int(patch)
    return (
        "ViT",
        sample_shape,
        int(getattr(cfg, "num_attention_heads", 12)),
        int(getattr(cfg, "hidden_size", 768)),
        int(getattr(cfg, "intermediate_size", 3072)),
        patch_sig,
    )
|