# Uploaded via huggingface_hub (commit 70b8d48, verified)
"""Shared utilities used across core and adapters.
Consolidates helpers that are generic (device/dtype, seeding, shapes, rounding,
parameter grouping, model copying, etc.). Keep this file dependency-light.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import copy
import random
import numpy as np
import torch
import torch.nn as nn
# -----------------------------------------------------------------------------
# Device / dtype helpers
# -----------------------------------------------------------------------------
def as_like(x: torch.Tensor, val) -> torch.Tensor:
    """Wrap `val` as a tensor that lives on `x`'s device with `x`'s dtype."""
    target = dict(device=x.device, dtype=x.dtype)
    return torch.as_tensor(val, **target)
def first_param(module: nn.Module) -> torch.Tensor:
    """Return the first parameter of `module` (recursive), or a 0.0 scalar if it has none."""
    return next(iter(module.parameters(recurse=True)), torch.tensor(0.0))
def to_device_dtype(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Move/cast `x` so it matches `ref`'s device and dtype."""
    # Tensor.to(other_tensor) converts to the other tensor's device and dtype.
    return x.to(ref)
# -----------------------------------------------------------------------------
# Seeding & determinism
# -----------------------------------------------------------------------------
def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy, and PyTorch (CPU + all CUDA devices) RNGs.

    When `deterministic` is set, also pin cuDNN to its deterministic,
    non-benchmarking kernels for reproducible GPU runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
# -----------------------------------------------------------------------------
# Model parameter helpers
# -----------------------------------------------------------------------------
def freeze(module: nn.Module) -> None:
    """Turn off gradient tracking for every parameter in `module`."""
    # nn.Module.requires_grad_ applies recursively to all parameters.
    module.requires_grad_(False)
def unfreeze(module: nn.Module) -> None:
    """Turn gradient tracking back on for every parameter in `module`."""
    # nn.Module.requires_grad_ applies recursively to all parameters.
    module.requires_grad_(True)
def count_parameters(module: nn.Module, *, trainable_only: bool = False) -> int:
    """Total element count across `module`'s parameters.

    With `trainable_only=True`, only parameters with `requires_grad` count.
    """
    params = module.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
# -----------------------------------------------------------------------------
# Shape/signature helpers
# -----------------------------------------------------------------------------
def input_spec_vision(sample) -> Tuple[int, int, int]:
    """Accept either a 4D tensor [B,3,H,W] or a 4-tuple (B,3,H,W). Returns (B,H,W)."""
    if isinstance(sample, torch.Tensor):
        dims = sample.shape
    elif isinstance(sample, (tuple, list)) and len(sample) == 4:
        dims = sample
    else:
        raise ValueError("sample must be a tensor [B,3,H,W] or a 4-tuple (B,3,H,W)")
    batch, _channels, height, width = dims
    return int(batch), int(height), int(width)
# -----------------------------------------------------------------------------
# Rounding / multiples
# -----------------------------------------------------------------------------
def round_down_multiple(n: int, m: int) -> int:
    """Floor `n` to a multiple of `m`, clamped to at least `m`.

    A missing or <=1 multiple disables rounding (result is at least 1).
    """
    if m is None or m <= 1:
        return max(1, int(n))
    n = int(n)
    floored = n - (n % m)
    return floored if floored >= m else m
def clamp_int(v: int, lo: int, hi: int) -> int:
    """Clamp `v` (cast to int) into the closed range [lo, hi]."""
    # Apply the upper bound first, then the lower bound (lo wins if lo > hi,
    # matching max(lo, min(v, hi)) semantics).
    bounded_above = min(int(v), hi)
    return max(lo, bounded_above)
# -----------------------------------------------------------------------------
# Slicing helpers
# -----------------------------------------------------------------------------
@torch.no_grad()
def slice_linear(mat: nn.Linear, keep_in: Optional[Sequence[int]] = None, keep_out: Optional[Sequence[int]] = None) -> nn.Linear:
    """Build a new nn.Linear keeping only the selected input/output features.

    Args:
        mat: source layer (left untouched).
        keep_out: row indices (output features) to keep; None keeps all.
        keep_in: column indices (input features) to keep; None keeps all.

    Returns:
        A fresh nn.Linear on the same device and dtype as `mat`'s weights,
        with the selected rows/columns copied in.
    """
    W = mat.weight.detach()
    b = mat.bias.detach() if mat.bias is not None else None
    if keep_out is not None:
        # index_select requires integer (long) indices.
        idx_out = torch.as_tensor(keep_out, device=W.device, dtype=torch.long)
        W = W.index_select(0, idx_out)
        if b is not None:
            b = b.index_select(0, idx_out)
    if keep_in is not None:
        idx_in = torch.as_tensor(keep_in, device=W.device, dtype=torch.long)
        W = W.index_select(1, idx_in)
    out_f, in_f = W.shape
    # Match device AND dtype of the source weights: the previous version only
    # moved device, silently upcasting fp16/bf16/fp64 layers to default fp32.
    new = nn.Linear(in_f, out_f, bias=(b is not None)).to(device=W.device, dtype=W.dtype)
    new.weight.copy_(W)
    if b is not None:
        new.bias.copy_(b)
    return new
# -----------------------------------------------------------------------------
# Copying & detaching models
# -----------------------------------------------------------------------------
def deepcopy_eval_cpu(module: nn.Module) -> nn.Module:
    """Return an independent CPU copy of `module` switched to eval mode."""
    clone = copy.deepcopy(module)
    clone = clone.cpu()
    return clone.eval()
# -----------------------------------------------------------------------------
# Gradient utilities
# -----------------------------------------------------------------------------
def zero_if_any(params: Iterable[torch.Tensor]) -> None:
    """Drop any accumulated `.grad` tensors on `params` (set to None, freeing memory)."""
    for tensor in params:
        if tensor.grad is None:
            continue
        tensor.grad = None
def any_grad(params: Iterable[torch.Tensor]) -> bool:
    """True if at least one of `params` currently carries a `.grad` tensor."""
    return any(p.grad is not None for p in params)
# -----------------------------------------------------------------------------
# For fine-tuning
# -----------------------------------------------------------------------------
def ensure_trainable_parameters(module: nn.Module, *, requires_grad: bool = True) -> nn.Module:
    """Swap every registered parameter for a fresh detached clone.

    Rebuilding each tensor as a new nn.Parameter drops any 'inference tensor'
    tag and re-enables autograd. The module is mutated in place and returned.
    """
    for sub in module.modules():
        # Snapshot the registry before mutating it via setattr.
        registered = list(sub._parameters.items())
        for name, param in registered:
            if param is None:
                continue
            replacement = nn.Parameter(param.detach().clone(), requires_grad=requires_grad)
            # setattr on an nn.Module re-registers the parameter under `name`.
            setattr(sub, name, replacement)
    return module
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
@dataclass
class ExportRounding:
    """Post-export rounding/snapping knobs for head and FFN sizes.

    NOTE(review): field semantics inferred from names — confirm against callers.
    """
    head_floor_post: int = 1  # presumably the minimum head count kept after export
    head_multiple_post: int = 1  # presumably snap the kept head count to a multiple of this
    ffn_min_keep_ratio_post: float = 0.0  # presumably lower bound on kept FFN width ratio
    ffn_snap_groups_post: int = 1  # presumably snap FFN width to this many groups
def shape_signature_vit(cfg, sample_shape: Tuple[int, int, int, int]) -> Tuple:
    """Build a hashable signature for a ViT config plus input shape.

    Args:
        cfg: config object; missing attributes fall back to ViT-Base defaults
            (12 heads, 768 hidden, 3072 intermediate, patch 16).
        sample_shape: (B, C, H, W) input shape, included verbatim.

    Returns:
        ("ViT", sample_shape, heads, hidden, intermediate, patch) where
        `patch` is an int, or a tuple when the config stores one.

    Raises:
        ValueError: if `sample_shape` is not length 4.
    """
    if len(sample_shape) != 4:
        raise ValueError("sample_shape must be a 4-tuple (B, C, H, W)")
    # Read patch_size once (the original re-read it three times with an
    # inconsistent fallback in the tuple branch) and normalize it so the
    # signature stays hashable.
    patch = getattr(cfg, "patch_size", 16)
    patch = tuple(patch) if isinstance(patch, (tuple, list)) else int(patch)
    return (
        "ViT",
        sample_shape,
        int(getattr(cfg, "num_attention_heads", 12)),
        int(getattr(cfg, "hidden_size", 768)),
        int(getattr(cfg, "intermediate_size", 3072)),
        patch,
    )