// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fips140cache
import (
"context"
"errors"
"runtime"
"sync"
"testing"
"time"
)
// TestCache exercises Cache.Get. It covers:
//   - returning the cached value while check accepts it;
//   - replacing the value when check rejects it;
//   - not caching, or replacing with, a value whose constructor returned an error;
//   - evicting entries once their key is unreachable;
//   - concurrent first Gets for the same key.
func TestCache(t *testing.T) {
	c := new(Cache[key, value])
	checkTrue := func(*value) bool { return true }
	checkFalse := func(*value) bool { return false }
	// newNotCalled fails the test if Get invokes it, i.e. on an unexpected
	// cache miss. It is only passed to Get on the test goroutine.
	newNotCalled := func() (*value, error) {
		t.Helper()
		t.Fatal("new called")
		return nil, nil
	}

	k1 := newKey()
	v1 := &value{}
	v, err := c.Get(k1, func() (*value, error) { return v1, nil }, checkTrue)
	expectValue(t, v, err, v1)

	// Cached value is returned if check is true.
	v, err = c.Get(k1, newNotCalled, checkTrue)
	expectValue(t, v, err, v1)

	// New value is returned and cached if check is false.
	v2 := &value{}
	v, err = c.Get(k1, func() (*value, error) { return v2, nil }, checkFalse)
	expectValue(t, v, err, v2)
	v, err = c.Get(k1, newNotCalled, checkTrue)
	expectValue(t, v, err, v2)
	expectMapSize(t, c, 1)

	// Cache is evicted when key becomes unreachable.
	waitUnreachable(t, &k1)
	expectMapSize(t, c, 0)

	// Value is not cached if new returns an error.
	k2 := newKey()
	err1 := errors.New("error")
	_, err = c.Get(k2, func() (*value, error) { return nil, err1 }, checkTrue)
	if err != err1 {
		t.Errorf("got %v, want %v", err, err1)
	}
	expectMapSize(t, c, 0)

	// Value is not replaced if check is false and new returns an error.
	v, err = c.Get(k2, func() (*value, error) { return v1, nil }, checkTrue)
	expectValue(t, v, err, v1)
	_, err = c.Get(k2, func() (*value, error) { return v2, err1 }, checkFalse)
	if err != err1 {
		t.Errorf("got %v, want %v", err, err1)
	}
	v, err = c.Get(k2, newNotCalled, checkTrue)
	expectValue(t, v, err, v1)
	expectMapSize(t, c, 1)

	// Cache is evicted for keys used only once.
	k3 := newKey()
	v, err = c.Get(k3, func() (*value, error) { return v1, nil }, checkTrue)
	expectValue(t, v, err, v1)
	expectMapSize(t, c, 2)
	waitUnreachable(t, &k2)
	waitUnreachable(t, &k3)
	expectMapSize(t, c, 0)

	// When two goroutines race, the returned value may be the new or old one,
	// but the map must shrink to 0.
	keys := make([]*key, 100)
	for i := range keys {
		keys[i] = newKey()
		v1, v2 := &value{}, &value{}
		start := make(chan struct{})
		// The racing goroutines only record their results. The results are
		// checked on the test goroutine after wg.Wait, because expectValue
		// may call t.Fatal, and FailNow must not be called from other
		// goroutines.
		var (
			got1, got2       *value
			gotErr1, gotErr2 error
		)
		var wg sync.WaitGroup
		wg.Add(2)
		go func() {
			// Deferred so wg.Wait can never block forever, even if this
			// goroutine exits early.
			defer wg.Done()
			<-start
			got1, gotErr1 = c.Get(keys[i], func() (*value, error) { return v1, nil }, checkTrue)
		}()
		go func() {
			defer wg.Done()
			<-start
			got2, gotErr2 = c.Get(keys[i], func() (*value, error) { return v2, nil }, checkTrue)
		}()
		close(start)
		// wg.Wait orders the goroutines' writes before the reads below.
		wg.Wait()
		expectValue(t, got1, gotErr1, v1, v2)
		expectValue(t, got2, gotErr2, v1, v2)

		// Whichever value won the race is now cached, so v3 must not be used.
		v3 := &value{}
		v, err = c.Get(keys[i], func() (*value, error) { return v3, nil }, checkTrue)
		expectValue(t, v, err, v1, v2)
	}
	for i := range keys {
		waitUnreachable(t, &keys[i])
	}
	expectMapSize(t, c, 0)
}
// key is the cache key type used by these tests.
// The single blank pointer field gives it a non-zero size and makes it
// pointer-bearing. Presumably this ensures each newKey call yields its own
// trackable allocation; zero-size and tiny pointer-free objects get special
// allocator treatment. TODO confirm.
type key struct{ _ *int }
// value is the cached value type used by these tests. Tests compare values by
// pointer identity. The blank pointer field keeps the type non-zero-size, so
// distinct allocations have distinct addresses.
type value struct{ _ *int }
// newKey returns a pointer to a freshly allocated key. The call is kept out
// of line, so the returned pointer escapes and the key lives on the heap
// rather than in the caller's stack frame.
//
//go:noinline
func newKey() *key {
	return new(key)
}
// expectValue checks the result of a Cache.Get call.
// A non-nil err stops the test immediately via t.Fatal, so expectValue must
// be called from the goroutine running the test. Otherwise, it reports a
// failure unless v is pointer-identical to one of the want candidates.
func expectValue(t *testing.T, v *value, err error, want ...*value) {
	t.Helper()
	if err != nil {
		t.Fatal(err)
	}
	for _, w := range want {
		if v == w {
			return
		}
	}
	// %v prints each *value inside the slice as its address. %p applied to
	// the slice itself would print only the backing array's address, which
	// says nothing about the candidates.
	t.Errorf("got %p, want one of %v", v, want)
}
// expectMapSize reports a test failure unless c's internal map c.m ends up
// holding exactly want entries. It polls up to 10 times, pausing 100ms after
// each mismatch. On failure it reports the last observed count.
func expectMapSize(t *testing.T, c *Cache[key, value], want int) {
	t.Helper()
	countEntries := func() int {
		n := 0
		c.m.Range(func(_, _ any) bool {
			n++
			return true
		})
		return n
	}
	// Retry because the AddCleanup callbacks may not have run yet.
	got := 0
	for attempt := 0; attempt < 10; attempt++ {
		if got = countEntries(); got == want {
			return
		}
		time.Sleep(100 * time.Millisecond)
	}
	t.Errorf("got %d, want %d", got, want)
}
// waitUnreachable attaches a cleanup to the key *k, sets *k to nil to drop
// the caller's reference, and repeatedly runs the garbage collector until
// that cleanup fires, demonstrating that the key has become unreachable.
// It fails the test if the wait ends with a context error other than
// context.Canceled.
func waitUnreachable(t *testing.T, k **key) {
	t.Helper()
	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	// The cleanup captures only cancel and takes a nil arg. Neither may
	// reference *k, or the key could never be collected.
	runtime.AddCleanup(*k, func(_ *int) { cancel() }, nil)
	*k = nil
	for ctx.Err() == nil {
		runtime.GC()
	}
	if ctx.Err() != context.Canceled {
		t.Fatal(ctx.Err())
	}
}
|