|
|
"""Shared utilities used across core and adapters. |
|
|
|
|
|
Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding, |
|
|
parameter grouping, model copying, etc.). Keep this file dependency-light. |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple |
|
|
|
|
|
import copy |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def as_like(x: torch.Tensor, val) -> torch.Tensor:
    """Return ``val`` as a tensor living on ``x``'s device with ``x``'s dtype.

    ``val`` may be a Python scalar, sequence, array or tensor; conversion is
    done with ``torch.as_tensor`` (no copy when nothing needs to change).
    """
    target_dtype = x.dtype
    target_device = x.device
    return torch.as_tensor(val, dtype=target_dtype, device=target_device)
|
|
|
|
|
|
|
|
def first_param(module: nn.Module) -> torch.Tensor:
    """Return the first parameter yielded by ``module.parameters(recurse=True)``.

    A module without any parameters yields a fresh scalar ``torch.tensor(0.0)``
    (default dtype, CPU) instead, so callers always get a tensor back.
    """
    found = next(module.parameters(recurse=True), None)
    if found is None:
        return torch.tensor(0.0)
    return found
|
|
|
|
|
|
|
|
def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Move/cast ``x`` so it matches ``ref``'s device and dtype.

    Returns ``x`` itself when it already matches (``Tensor.to`` semantics).
    """
    # Tensor.to(other) adopts both the device and the dtype of `other`.
    return x.to(ref)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU + all CUDA devices).

    Args:
        seed: value passed to every seeding function.
        deterministic: when true, additionally force cuDNN into deterministic
            mode and turn off its benchmark autotuner. When false the cuDNN
            flags are left exactly as they were.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def freeze(module: nn.Module) -> None:
    """Turn off ``requires_grad`` on every parameter of ``module``, submodules included."""
    params = module.parameters()
    for param in params:
        param.requires_grad_(False)
|
|
|
|
|
|
|
|
def unfreeze(module: nn.Module) -> None:
    """Turn on ``requires_grad`` on every parameter of ``module``, submodules included."""
    params = module.parameters()
    for param in params:
        param.requires_grad_(True)
|
|
|
|
|
|
|
|
def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
    """Total number of scalar elements across ``module.parameters()``.

    Args:
        module: module to inspect (recursively).
        trainable_only: if true, only parameters with ``requires_grad`` count.

    Returns:
        Sum of ``numel()`` over the selected parameters (0 when there are none).
    """
    selected = module.parameters()
    if trainable_only:
        selected = (p for p in selected if p.requires_grad)
    return sum(p.numel() for p in selected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def input_spec_vision(sample) -> Tuple[int, int, int]:
    """Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W).

    Args:
        sample: a rank-4 ``torch.Tensor`` or a length-4 tuple/list (a
            ``torch.Size`` qualifies, since it subclasses ``tuple``). The channel
            entry is read but not validated.

    Returns:
        ``(B, H, W)`` as Python ints.

    Raises:
        ValueError: if ``sample`` is a tensor whose rank is not 4, or is neither
            a tensor nor a length-4 tuple/list.
    """
    if isinstance(sample, torch.Tensor):
        # Check rank explicitly so a wrong-rank tensor gets a clear message
        # instead of a bare "not enough values to unpack" error.
        if sample.dim() != 4:
            raise ValueError(
                f"sample tensor must be 4D [B,3,H,W], got shape {tuple(sample.shape)}"
            )
        B, _C, H, W = sample.shape
        return int(B), int(H), int(W)

    if isinstance(sample, (tuple, list)) and len(sample) == 4:
        B, _C, H, W = sample
        return int(B), int(H), int(W)

    raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def round_down_multiple(n: int, m: Optional[int]) -> int:
    """Round ``n`` down to a multiple of ``m``, but never below ``m`` itself.

    Args:
        n: value to round; converted with ``int()`` first, so floats are
            truncated toward zero before rounding.
        m: the multiple. ``None`` or any value ``<= 1`` disables rounding.

    Returns:
        With rounding disabled: ``max(1, int(n))``. Otherwise, the largest
        multiple of ``m`` that is ``<= int(n)``, floored at ``m``. This means
        inputs smaller than ``m`` are rounded *up* to ``m``.
    """
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    # Floor at `m` so the result is always a positive multiple (never 0).
    return max(m, (n // m) * m)
|
|
|
|
|
|
|
|
def clamp_int(v: int, lo: int, hi: int) -> int:
    """Truncate ``v`` with ``int()`` and clamp it into ``[lo, hi]``.

    The upper bound is applied first, then the lower bound, so when
    ``lo > hi`` the result is ``lo``. Bounds are returned as given (not
    converted with ``int()``).
    """
    value = int(v)
    # Mirrors min(value, hi): keep `value` unless `hi` is strictly smaller.
    capped = hi if hi < value else value
    # Mirrors max(lo, capped): keep `lo` unless `capped` is strictly larger.
    return capped if capped > lo else lo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def slice_linear(
    mat: nn.Linear,
    keep_in: Optional[Sequence[int]] = None,
    keep_out: Optional[Sequence[int]] = None,
) -> nn.Linear:
    """Build a new ``nn.Linear`` that keeps only selected input/output features.

    Runs under ``torch.no_grad()``; ``mat`` itself is not modified.

    Args:
        mat: source layer.
        keep_in: indices of input features (weight columns) to keep, in the
            given order. ``None`` keeps all of them.
        keep_out: indices of output features (weight rows and bias entries) to
            keep, in the given order. ``None`` keeps all of them.

    Returns:
        A freshly allocated ``nn.Linear`` on the same device and with the same
        dtype as ``mat.weight``, holding copies of the selected weights (and
        bias, if ``mat`` has one).
    """
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None

    if keep_out is not None:
        # Force an integer index dtype: an empty list would otherwise become a
        # float tensor, which index_select rejects.
        idx_out = torch.as_tensor(keep_out, dtype=torch.long, device=W.device)
        W = W.index_select(0, idx_out)
        if b is not None:
            b = b.index_select(0, idx_out)

    if keep_in is not None:
        idx_in = torch.as_tensor(keep_in, dtype=torch.long, device=W.device)
        W = W.index_select(1, idx_in)

    out_f, in_f = W.shape
    # Allocate with the source device AND dtype so half/bfloat16/double layers
    # are not silently cast to the default dtype.
    new = nn.Linear(in_f, out_f, bias=(b is not None), device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
    """Return an independent deep copy of ``module`` moved to CPU and set to eval mode.

    The original module is left untouched (device and train/eval flag unchanged).
    """
    clone = copy.deepcopy(module)
    # Both calls mutate `clone` in place (and return it).
    clone.cpu()
    clone.eval()
    return clone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def zero_if_any(params: Iterable[torch.Tensor]) -> None:
    """Release gradients by setting ``.grad`` to ``None`` on every tensor that has one.

    Despite the name, gradients are dropped (set to ``None``), not filled with
    zeros. Tensors without a gradient are left alone.
    """
    with_grad = (t for t in params if t.grad is not None)
    for t in with_grad:
        t.grad = None
|
|
|
|
|
|
|
|
def any_grad(params: Iterable[torch.Tensor]) -> bool:
    """Return True as soon as some tensor in ``params`` has a non-``None`` ``.grad``.

    Stops consuming ``params`` at the first hit; an empty iterable gives False.
    """
    return any(t.grad is not None for t in params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
    """
    Rebuild all parameters as fresh nn.Parameter tensors (detach+clone),
    which drops any 'inference tensor' tag and re-enables autograd.

    A Parameter shared by several submodules (e.g. tied weights) is rebuilt
    exactly once, and that single new Parameter is re-registered everywhere
    the old one appeared, so weight tying survives the rebuild.

    Args:
        module: module whose parameters (recursively) are replaced in place.
        requires_grad: ``requires_grad`` flag given to every new Parameter.

    Returns:
        ``module`` itself, after in-place modification.
    """
    # id(old_param) -> (old_param, new_param). The old tensor is stored too so it
    # stays alive for the whole pass and its id() cannot be reused meanwhile.
    rebuilt: dict = {}
    for mod in module.modules():
        # Snapshot the items: setattr below rewrites mod._parameters.
        for name, p in list(mod._parameters.items()):
            if p is None:
                continue
            entry = rebuilt.get(id(p))
            if entry is None:
                new_p = nn.Parameter(p.detach().clone(), requires_grad=requires_grad)
                rebuilt[id(p)] = (p, new_p)
            else:
                new_p = entry[1]
            setattr(mod, name, new_p)
    return module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ExportRounding:
    """Dataclass holding four rounding/snapping settings, all with defaults.

    NOTE(review): nothing in the visible code reads these fields; the meanings
    in the comments below are inferred from the names only. Confirm them
    against the consumer before relying on them.
    """

    # presumably the minimum number of attention heads kept (default 1) — TODO confirm
    head_floor_post: int = 1

    # presumably kept head counts are snapped to a multiple of this (1 = no snapping) — TODO confirm
    head_multiple_post: int = 1

    # presumably the minimum fraction of FFN units to keep (0.0 = no floor) — TODO confirm
    ffn_min_keep_ratio_post: float = 0.0

    # presumably the FFN width is snapped in units of this many groups (1 = no snapping) — TODO confirm
    ffn_snap_groups_post: int = 1
|
|
|
|
|
|
|
|
def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
    """Build a hashable signature describing a ViT config plus an input shape.

    Args:
        cfg: config object. ``num_attention_heads``, ``hidden_size``,
            ``intermediate_size`` and ``patch_size`` are read via ``getattr``,
            with defaults 12, 768, 3072 and 16 respectively.
        sample_shape: 4-entry ``(B, C, H, W)`` shape; a tuple, list or
            ``torch.Size`` is accepted.

    Returns:
        ``("ViT", shape, heads, hidden, intermediate, patch)``. ``shape`` is
        ``sample_shape`` normalized to a tuple of ints, so the signature is
        hashable and independent of the container type passed in. ``patch`` is
        an int, or a tuple of ints when ``cfg.patch_size`` is a tuple/list.

    Raises:
        ValueError: if ``sample_shape`` does not have exactly 4 entries.
    """
    shape = tuple(int(d) for d in sample_shape)
    if len(shape) != 4:
        raise ValueError(
            f"sample_shape must have 4 entries (B, C, H, W), got {len(shape)}"
        )

    patch = getattr(cfg, "patch_size", 16)
    if isinstance(patch, (tuple, list)):
        patch_sig = tuple(int(p) for p in patch)
    else:
        patch_sig = int(patch)

    return (
        "ViT",
        shape,
        int(getattr(cfg, "num_attention_heads", 12)),
        int(getattr(cfg, "hidden_size", 768)),
        int(getattr(cfg, "intermediate_size", 3072)),
        patch_sig,
    )
|
|
|