| |
| |
| |
|
|
| |
|
|
| package main |
|
|
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "log" |
| "os" |
| "slices" |
| "strconv" |
|
|
| "internal/runtime/gc" |
| "internal/runtime/gc/internal/gen" |
| ) |
|
|
// header is written at the very top of each generated assembly file. The
// "Code generated ... DO NOT EDIT." line follows the Go convention that
// marks a file as machine-generated; it is followed by one blank line.
const header = "// Code generated by mkasm.go. DO NOT EDIT." + "\n\n"
|
|
| func main() { |
| generate("expand_amd64.s", genExpanders) |
| } |
|
|
| func generate(fileName string, genFunc func(*gen.File)) { |
| var buf bytes.Buffer |
| tee := io.MultiWriter(&buf, os.Stdout) |
|
|
| file := gen.NewFile(tee) |
|
|
| genFunc(file) |
|
|
| fmt.Fprintf(tee, header) |
| file.Compile() |
|
|
| f, err := os.Create(fileName) |
| if err != nil { |
| log.Fatal(err) |
| } |
| defer f.Close() |
| _, err = f.Write(buf.Bytes()) |
| if err != nil { |
| log.Fatal(err) |
| } |
| } |
|
|
| func genExpanders(file *gen.File) { |
| gcExpandersAVX512 := make([]*gen.Func, len(gc.SizeClassToSize)) |
| for sc, ob := range gc.SizeClassToSize { |
| if gc.SizeClassToNPages[sc] != 1 { |
| |
| |
| continue |
| } |
| if ob > gc.MinSizeForMallocHeader { |
| |
| break |
| } |
|
|
| xf := int(ob) / 8 |
| log.Printf("size class %d bytes, expansion %dx", ob, xf) |
|
|
| fn := gen.NewFunc(fmt.Sprintf("expandAVX512_%d<>", xf)) |
| ptrObjBits := gen.Arg[gen.Ptr[gen.Uint8x64]](fn) |
|
|
| if xf == 1 { |
| expandIdentity(ptrObjBits) |
| } else { |
| ok := gfExpander(xf, ptrObjBits) |
| if !ok { |
| log.Printf("failed to generate expander for size class %d", sc) |
| } |
| } |
| file.AddFunc(fn) |
| gcExpandersAVX512[sc] = fn |
| } |
|
|
| |
| file.AddConst("·gcExpandersAVX512", gcExpandersAVX512) |
| } |
|
|
// mat8x8 is an 8x8 bit matrix stored as eight one-byte rows. In gfExpander,
// row r corresponds to bit r of an output byte and has exactly one bit set:
// the position of the input bit that output bit copies. Because the only
// field is an array, the type is comparable, and gfExpander uses it directly
// as a map key.
type mat8x8 struct {
	mat [8]byte // mat[r] is row r
}
|
|
| func matGroupToVec(mats *[8]mat8x8) [8]uint64 { |
| var out [8]uint64 |
| for i, mat := range mats { |
| for j, row := range mat.mat { |
| |
| out[i] |= uint64(row) << ((7 - j) * 8) |
| } |
| } |
| return out |
| } |
|
|
| |
| func expandIdentity(ptrObjBits gen.Ptr[gen.Uint8x64]) { |
| objBitsLo := gen.Deref(ptrObjBits) |
| objBitsHi := gen.Deref(ptrObjBits.AddConst(64)) |
| gen.Return(objBitsLo, objBitsHi) |
| } |
|
|
// gfExpander emits an expander for expansion factor f (f > 1) into the
// current gen function. Output bit oBit, for oBit in [0, 1024), is a copy of
// input bit oBit/f, so each input bit is replicated f times. The 1024 output
// bits are returned as two Uint8x64 values: outLo holds output bytes 0-63
// and outHi holds bytes 64-127.
//
// The construction has three stages:
//  1. For each output byte, record the single input byte it reads from and
//     an 8x8 bit matrix saying which input bit feeds each output bit.
//  2. Batch output bytes that share a matrix into groups of up to
//     termsPerGroup, and batch groups into super-groups. For each
//     super-group, emit one input-byte shuffle followed by one GF2P8Affine
//     multiply (presumably VGF2P8AFFINEQB, with one matrix per 8-byte
//     lane — TODO confirm against gen).
//  3. Gather the bytes of those products into output order with
//     genShuffle.
//
// gfExpander returns false, after logging the reason, if any stage cannot
// handle this f.
func gfExpander(f int, ptrObjBits gen.Ptr[gen.Uint8x64]) bool {
	// Only one 64-byte vector is loaded. With f >= 2, the 1024 output bits
	// read at most 1024/f <= 512 input bits.
	objBits := gen.Deref(ptrObjBits)

	// A term describes how one output byte (oByte) is produced from one
	// input byte (iByte) via the bit matrix mat.
	type term struct {
		iByte, oByte int
		mat          mat8x8
	}
	var terms []term

	// Stage 1: build one term per output byte. This only works if all eight
	// bits of an output byte come from the same input byte.
	for oByte := 0; oByte < 1024/8; oByte++ {
		var byteMat mat8x8
		iByte := -1
		for oBit := oByte * 8; oBit < oByte*8+8; oBit++ {
			iBit := oBit / f
			if iByte == -1 {
				iByte = iBit / 8
			} else if iByte != iBit/8 {
				log.Printf("output byte %d straddles input bytes %d and %d", oByte, iByte, iBit/8)
				return false
			}
			// Row oBit%8 has exactly one bit set: the position of the
			// source bit within the input byte.
			byteMat.mat[oBit%8] = 1 << (iBit % 8)
		}
		terms = append(terms, term{iByte, oByte, byteMat})
	}

	// Disabled debugging aid. It prints a CSV table with one row per output
	// byte and one column per input byte. Each cell holds a letter
	// identifying that term's distinct matrix, or is empty if there is no
	// term.
	if false {
		maxIByte, maxOByte := 0, 0
		for _, term := range terms {
			maxIByte = max(maxIByte, term.iByte)
			maxOByte = max(maxOByte, term.oByte)
		}
		iToO := make([][]rune, maxIByte+1)
		for i := range iToO {
			iToO[i] = make([]rune, maxOByte+1)
		}
		matMap := make(map[mat8x8]int)
		for _, term := range terms {
			i, ok := matMap[term.mat]
			if !ok {
				i = len(matMap)
				matMap[term.mat] = i
			}
			iToO[term.iByte][term.oByte] = 'A' + rune(i)
		}
		for o := range maxOByte + 1 {
			fmt.Printf("%d", o)
			for i := range maxIByte + 1 {
				fmt.Printf(",")
				if mat := iToO[i][o]; mat != 0 {
					fmt.Printf("%c", mat)
				}
			}
			fmt.Println()
		}
	}

	// Stage 2: group terms that share a matrix. A group holds up to
	// termsPerGroup terms (8 bytes, matching one matrix slot in
	// matGroupToVec). A super-group holds up to groupsPerSuperGroup groups
	// (64 bytes, one vector).
	const termsPerGroup = 8
	const groupsPerSuperGroup = 8

	// matMap maps a matrix to the index in termGroups of the currently open
	// (not yet full) group using that matrix. allMats records every distinct
	// matrix seen.
	matMap := make(map[mat8x8]int)
	allMats := make(map[mat8x8]bool)
	var termGroups [][]term
	for _, term := range terms {
		allMats[term.mat] = true

		// Try to join the open group for this matrix.
		i, ok := matMap[term.mat]
		if ok && f > groupsPerSuperGroup {
			// For large f, a group must not mix output bytes that land in
			// different 64-byte output halves (outLo vs. outHi). If it would,
			// start a new group; the previous open group for this matrix is
			// abandoned, partially filled. NOTE(review): presumably this keeps
			// each final genShuffle within its four-input limit. TODO confirm.
			outRegister := termGroups[i][0].oByte / 64
			if term.oByte/64 != outRegister {
				ok = false
			}
		}
		if !ok {
			// Open a new group for this matrix.
			i = len(termGroups)
			matMap[term.mat] = i
			termGroups = append(termGroups, nil)
		}

		termGroups[i] = append(termGroups[i], term)

		// Close a full group; the next term with this matrix opens a new one.
		if len(termGroups[i]) == termsPerGroup {
			delete(matMap, term.mat)
		}
	}

	for i, termGroup := range termGroups {
		log.Printf("term group %d:", i)
		for _, term := range termGroup {
			log.Printf(" %+v", term)
		}
	}

	// Decide how many groups go into each super-group (sgSize), and how many
	// super-groups (one shuffle + multiply each) are needed.
	var sgSize, nSuperGroups int
	oneMatVec := f <= groupsPerSuperGroup
	if oneMatVec {
		// Every super-group must use the same matrix vector (checked below).
		// Use the largest multiple of the distinct-matrix count that fits in
		// groupsPerSuperGroup. NOTE(review): if len(allMats) exceeded 8,
		// sgSize would be 0 and the next line would divide by zero.
		sgSize = groupsPerSuperGroup / len(allMats) * len(allMats)
		nSuperGroups = (len(termGroups) + sgSize - 1) / sgSize
	} else {
		// Each super-group may use its own matrix vector, so fill each one
		// with up to groupsPerSuperGroup groups.
		sgSize = 8
		nSuperGroups = (len(termGroups) + groupsPerSuperGroup - 1) / groupsPerSuperGroup
	}

	// Emit each super-group. perm records, for every output byte, its index
	// in the concatenation of all products (64 bytes per product, in
	// super-group order).
	var matGroup [8]mat8x8
	var matMuls []gen.Uint8x64
	var perm [128]int
	for sgi := range nSuperGroups {
		// 0xff fills input-shuffle slots that no term uses.
		var iperm [64]uint8
		for i := range iperm {
			iperm[i] = 0xff
		}
		// Consume the next sgSize groups, or whatever remains.
		superGroup := termGroups[:min(len(termGroups), sgSize)]
		termGroups = termGroups[len(superGroup):]
		// Group i occupies bytes i*termsPerGroup .. i*termsPerGroup+7 and
		// uses matrix thisMatGroup[i].
		var thisMatGroup [8]mat8x8
		for i, termGroup := range superGroup {
			// All terms in a group share a matrix; take it from the first.
			thisMatGroup[i] = termGroup[0].mat
			for j, term := range termGroup {
				// Slot i*termsPerGroup+j reads this term's input byte...
				iperm[i*termsPerGroup+j] = uint8(term.iByte)
				// ...and ends up at this index of the concatenated products.
				perm[term.oByte] = sgi*groupsPerSuperGroup*termsPerGroup + i*termsPerGroup + j
			}
		}
		log.Printf("input permutation %d: %v", sgi, iperm)

		// When oneMatVec is set, every super-group must match the first
		// one's matrices.
		if oneMatVec {
			if sgi == 0 {
				matGroup = thisMatGroup
			} else if matGroup != thisMatGroup {
				log.Printf("super-groups have different matrixes:\n%+v\n%+v", matGroup, thisMatGroup)
				return false
			}
		}

		// Shuffle the input bytes into slot order, then apply the
		// super-group's matrices.
		matConst := gen.ConstUint64x8(matGroupToVec(&thisMatGroup), fmt.Sprintf("*_mat%d<>", sgi))
		inOp := objBits.Shuffle(gen.ConstUint8x64(iperm, fmt.Sprintf("*_inShuf%d<>", sgi)))
		matMul := matConst.GF2P8Affine(inOp)
		matMuls = append(matMuls, matMul)
	}

	log.Printf("output permutation: %v", perm)

	// Stage 3: gather output bytes from the products. perm[:64] builds the
	// low output vector and perm[64:] the high one. NOTE(review): the log
	// message below reports len(matMuls), but genShuffle fails based on how
	// many distinct products one half references. Also, genShuffle accepts
	// 3 inputs as well, despite the "1, 2, or 4" wording.
	outLo, ok := genShuffle("*_outShufLo", (*[64]int)(perm[:64]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	outHi, ok := genShuffle("*_outShufHi", (*[64]int)(perm[64:]), matMuls...)
	if !ok {
		log.Printf("bad number of inputs to final shuffle: %d != 1, 2, or 4", len(matMuls))
		return false
	}
	gen.Return(outLo, outHi)

	return true
}
|
|
| func genShuffle(name string, perm *[64]int, args ...gen.Uint8x64) (gen.Uint8x64, bool) { |
| |
| var vperm [64]byte |
|
|
| |
| var inputs []int |
| for i, src := range perm { |
| inputIdx := slices.Index(inputs, src/64) |
| if inputIdx == -1 { |
| inputIdx = len(inputs) |
| inputs = append(inputs, src/64) |
| } |
| vperm[i] = byte(src%64 | (inputIdx << 6)) |
| } |
|
|
| |
| switch len(inputs) { |
| case 1: |
| constOp := gen.ConstUint8x64(vperm, name) |
| return args[inputs[0]].Shuffle(constOp), true |
| case 2: |
| constOp := gen.ConstUint8x64(vperm, name) |
| return args[inputs[0]].Shuffle2(args[inputs[1]], constOp), true |
| } |
|
|
| |
| |
| |
| |
| |
| |
| var vperms [2][64]byte |
| var masks [2]uint64 |
| for j, idx := range vperm { |
| for i := range vperms { |
| vperms[i][j] = 0xff |
| } |
| if idx == 0xff { |
| continue |
| } |
| vperms[idx/128][j] = idx % 128 |
| masks[idx/128] |= uint64(1) << j |
| } |
|
|
| |
| if masks[0]^masks[1] != ^uint64(0) { |
| panic("bad shuffle!") |
| } |
|
|
| |
| constOps := make([]gen.Uint8x64, len(vperms)) |
| for i, v := range vperms { |
| constOps[i] = gen.ConstUint8x64(v, name+strconv.Itoa(i)) |
| } |
|
|
| |
| switch len(inputs) { |
| case 3: |
| r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0])) |
| r1 := args[inputs[2]].ShuffleZeroed(constOps[1], gen.ConstMask64(masks[1])) |
| return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true |
| case 4: |
| r0 := args[inputs[0]].Shuffle2Zeroed(args[inputs[1]], constOps[0], gen.ConstMask64(masks[0])) |
| r1 := args[inputs[2]].Shuffle2Zeroed(args[inputs[3]], constOps[1], gen.ConstMask64(masks[1])) |
| return r0.ToUint64x8().Or(r1.ToUint64x8()).ToUint8x64(), true |
| } |
|
|
| |
| |
| return args[0], false |
| } |
|
|