WCNegentropy committed on
Commit
2bb2eda
·
verified ·
1 Parent(s): 060d6ba

Remove nested directory: BitTransformerLM/bit_transformer/distil.py

Browse files
BitTransformerLM/bit_transformer/distil.py DELETED
@@ -1,90 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from typing import Optional
5
-
6
- import torch
7
- import torch.nn as nn
8
-
9
- from .model import BitTransformerLM
10
-
11
-
12
@dataclass
class TelemetryLog:
    """Record of attention telemetry collected over a run.

    Attributes:
        attention_maps: Stacked attention weights with shape
            ``[steps, heads, seq, seq]``.
    """

    attention_maps: torch.Tensor
21
-
22
-
23
def distill_step(model: BitTransformerLM, scale: float, telemetry: TelemetryLog) -> BitTransformerLM:
    """Return a pruned copy of ``model`` guided by attention telemetry.

    Args:
        model: Teacher model to distill from; it is left unmodified.
        scale: Fraction of weights to retain in each linear layer
            (must satisfy ``0 < scale <= 1``).
        telemetry: Logged attention maps of shape ``[steps, heads, seq, seq]``
            used to estimate per-head importance.

    Returns:
        A new ``BitTransformerLM`` with the same architecture and weights as
        ``model``, except that roughly the least-important ``(1 - scale)``
        fraction of weights in every ``nn.Linear`` is zeroed out.

    Raises:
        ValueError: If ``scale`` lies outside ``(0, 1]``.

    Per-head importance is the mean attention activation over all steps and
    positions, multiplied by the number of visits (non-zero attention
    entries). Weight rows are scored by the head they are assumed to belong
    to when ``out_features`` divides evenly by the head count; otherwise all
    rows of that layer share a uniform score.
    """
    if not (0.0 < scale <= 1.0):
        raise ValueError("scale must lie in (0, 1].")

    # Rebuild an architecturally identical student so the teacher stays intact.
    student = BitTransformerLM(
        d_model=model.d_model,
        nhead=model.layers[0].self_attn.num_heads,
        num_layers=model.num_layers,
        dim_feedforward=model.layers[0].linear1.out_features,
        max_seq_len=model.pos_enc.pe.size(0),
        lambda_K=model.lambda_K,
        lambda_C=model.lambda_C,
        lambda_S=model.lambda_S,
        reversible=model.reversible,
        use_checkpoint=model.use_checkpoint,
        use_autocast=model.use_autocast,
        use_act=model.use_act,
        act_threshold=model.act_threshold,
        chunk_size=model.chunk_size,
        overlap=model.overlap,
    )
    student.load_state_dict(model.state_dict())

    attn = telemetry.attention_maps  # [steps, heads, seq, seq]
    heads = attn.shape[1]
    mean_act = attn.mean(dim=(0, 2, 3))
    visits = (attn > 0).sum(dim=(0, 2, 3)).clamp_min(1)
    head_importance = mean_act * visits
    total = head_importance.sum()
    if total > 0:
        head_importance = head_importance / total
    else:
        # All-zero attention carries no signal; fall back to uniform
        # importance rather than dividing by zero, which would fill the
        # scores with NaNs and silently prune every weight below.
        head_importance = torch.full_like(head_importance, 1.0 / heads)

    prune_frac = 1.0 - scale

    for module in student.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            out_features = weight.size(0)
            if out_features % heads == 0:
                # Assumes output rows are grouped contiguously by head —
                # TODO confirm against the attention layout used by the model.
                repeats = out_features // heads
                row_scores = head_importance.repeat_interleave(repeats).view(out_features, 1)
            else:
                # Rows cannot be mapped to heads; score all rows equally.
                row_scores = head_importance.mean().expand(out_features, 1)

            importance = weight.abs() * row_scores
            k = int(importance.numel() * prune_frac)
            if k > 0:
                # Threshold at the k-th smallest score. Scores tied with the
                # threshold are pruned as well, so slightly more than k
                # weights may be zeroed in the presence of ties.
                thresh = torch.topk(importance.view(-1), k, largest=False).values.max()
                mask = importance > thresh
                weight.mul_(mask)
                if module.bias is not None:
                    # Zero a bias term only when its entire output row
                    # was pruned (no surviving weight feeds that output).
                    row_mask = mask.view(out_features, -1).any(dim=1)
                    module.bias.data.mul_(row_mask)

    return student