wujun committed on
Commit
8fd0890
·
1 Parent(s): 18d999c

correct rf8 function

Browse files
Files changed (1) hide show
  1. rf8.py +26 -1
rf8.py CHANGED
@@ -17,4 +17,29 @@ def rf8(model, n=4):
17
  data1, exps1 = get_residual(param.data)
18
  data2, exps2 = get_residual(param.data - data1)
19
  flags = (exps1-exps2 <= n)
20
- param.data = data1 + flags * data2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  data1, exps1 = get_residual(param.data)
18
  data2, exps2 = get_residual(param.data - data1)
19
  flags = (exps1-exps2 <= n)
20
+ param.data = data1 + flags * data2
21
+
22
+
23
def rf8_new(model):
    """Quantize every parameter of ``model`` in place to the 8-bit residual float format.

    Each value is replaced by ``sign * (2**e + r)``, where ``2**e`` is the
    largest power of two not exceeding ``abs(value)`` and ``r`` is whichever
    candidate residual ``2**(e + b)`` (``b`` taken from the bias table below)
    lies closest to ``abs(value) - 2**e``.  Results whose magnitude is below
    ``2**-12`` or above ``2**5`` are set to zero.  The quantized tensor is
    written back as ``torch.bfloat16`` on the parameter's original device.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. Parameters are modified in place via ``param.data``.
    """
    # Candidate residual exponents, relative to the leading exponent ``e``.
    # NOTE(review): the `1` entry can never win the nearest-residual search:
    # abs(value) - 2**e is always < 2**e, so the `0` entry is strictly closer.
    # It looks like a typo for `-1`; left unchanged pending confirmation,
    # because fixing it alters the quantization grid.
    bias = torch.tensor([-4, -3, -2, 1, 0], dtype=int)

    with torch.no_grad():
        for param in model.parameters():
            param_ = param.cpu()

            # frexp yields a mantissa in [0.5, 1), so subtracting 1 from its
            # exponent gives the power of two at or just below abs(value).
            signs = torch.sign(param_)
            exps = torch.frexp(param_)[1] - 1
            exponents = torch.exp2(exps)

            # Trailing axis enumerates the candidate residuals per element;
            # broadcasting against `bias` builds the (*shape, len(bias)) table.
            res_list = torch.exp2(exps.unsqueeze(-1) + bias)
            res_true = (torch.abs(param_) - exponents).unsqueeze(-1)

            indices = (res_true - res_list).abs().argmin(-1, keepdim=True)
            # Drop ONLY the trailing candidate axis. A bare squeeze() would
            # also remove genuine size-1 parameter dimensions (e.g. a (3, 1)
            # weight or a 1x1 conv kernel) and corrupt the parameter's shape
            # through broadcasting in the sum below.
            residuals = torch.gather(res_list, -1, indices).squeeze(-1)

            values = signs * (exponents + residuals)
            # Magnitudes outside the representable range are zeroed.
            values[values.abs() < 2**-12] = 0
            values[values.abs() > 2**5] = 0
            param.data = values.to(torch.bfloat16).to(param.device)