WCNegentropy committed
Commit 789fd7b · verified · 1 Parent(s): 3658b97

Remove nested directory: BitTransformerLM/bit_transformer/scale.py

BitTransformerLM/bit_transformer/scale.py DELETED
@@ -1,36 +0,0 @@
-import torch
-from typing import Dict
-from .model import BitTransformerLM
-import torch.nn as nn
-
-
-def expand_model(model: BitTransformerLM, new_params: Dict) -> BitTransformerLM:
-    """Return a new model with updated params and copied weights."""
-    new_model = BitTransformerLM(**new_params)
-    new_state = new_model.state_dict()
-    old_state = model.state_dict()
-
-    for k, v in old_state.items():
-        if k in new_state:
-            dest = new_state[k]
-            slices = tuple(slice(0, min(d, s)) for d, s in zip(dest.shape, v.shape))
-            dest[slices].copy_(v[slices])
-            if dest.shape != v.shape:
-                mask = torch.ones_like(dest, dtype=torch.bool)
-                mask[slices] = False
-                if "bias" in k:
-                    dest[mask] = 0.0
-                else:
-                    dest[mask] = 0.001 * torch.randn_like(dest[mask])
-
-    for k, v in new_state.items():
-        if k not in old_state:
-            if "bias" in k:
-                v.zero_()
-            elif v.dim() > 1:
-                nn.init.normal_(v, mean=0.0, std=1e-3)
-            else:
-                v.zero_()
-
-    new_model.load_state_dict(new_state)
-    return new_model
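
For reference, the helper removed in this commit implemented a grow-and-copy expansion: build a larger model, copy the overlapping weight slices from the old state dict, and initialize newly added slots to zero (for biases) or small noise (for weights). The sketch below demonstrates that behavior on a hypothetical `TinyLM` stand-in; the class and its `d_model` argument are illustrative assumptions, not part of the repository, and this is a minimal approximation rather than the project's own API.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for BitTransformerLM with a single size knob."""

    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)


def expand_model(model: nn.Module, new_params: dict) -> nn.Module:
    """Mirror of the deleted helper's core loop: copy overlapping slices."""
    new_model = TinyLM(**new_params)
    new_state = new_model.state_dict()  # detached tensors, safe to edit in place
    old_state = model.state_dict()
    for k, v in old_state.items():
        if k in new_state:
            dest = new_state[k]
            # Copy the region where the old and new shapes overlap.
            slices = tuple(slice(0, min(d, s)) for d, s in zip(dest.shape, v.shape))
            dest[slices].copy_(v[slices])
            if dest.shape != v.shape:
                # Newly grown slots: zeros for biases, small noise for weights.
                mask = torch.ones_like(dest, dtype=torch.bool)
                mask[slices] = False
                dest[mask] = 0.0 if "bias" in k else 0.001 * torch.randn_like(dest[mask])
    new_model.load_state_dict(new_state)
    return new_model


small = TinyLM(d_model=4)
large = expand_model(small, {"d_model": 8})
# The original 4x4 weight block survives inside the expanded 8x8 matrix.
assert torch.allclose(large.proj.weight[:4, :4], small.proj.weight)
```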