|
|
package utils |
|
|
|
|
|
import ( |
|
|
"crypto/tls" |
|
|
"crypto/x509" |
|
|
"fmt" |
|
|
"net" |
|
|
"os" |
|
|
"strings" |
|
|
"sync" |
|
|
"sync/atomic" |
|
|
"time" |
|
|
) |
|
|
|
|
|
type LocalCertificateLoader struct { |
|
|
CertFile string |
|
|
KeyFile string |
|
|
SNIGuard SNIGuardFunc |
|
|
|
|
|
lock sync.Mutex |
|
|
cache atomic.Pointer[localCertificateCache] |
|
|
} |
|
|
|
|
|
// SNIGuardFunc decides whether the certificate chosen by
// LocalCertificateLoader may be presented for the given ClientHello.
// Returning a non-nil error makes GetCertificate fail with that error.
type SNIGuardFunc func(hello *tls.ClientHelloInfo, certificate *tls.Certificate) error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// localCertificateCache is one loaded snapshot of the key pair. A snapshot is
// built by makeCache and never modified after it is published; a reload
// publishes a brand-new snapshot instead.
type localCertificateCache struct {
	// certificate is the parsed key pair, with Leaf populated.
	certificate *tls.Certificate

	// Modification times of CertFile and KeyFile as observed just before the
	// pair was loaded; a difference from the current times triggers a reload.
	certModTime, keyModTime time.Time
}
|
|
|
|
|
func (l *LocalCertificateLoader) InitializeCache() error { |
|
|
l.lock.Lock() |
|
|
defer l.lock.Unlock() |
|
|
|
|
|
cache, err := l.makeCache() |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
|
|
|
l.cache.Store(cache) |
|
|
return nil |
|
|
} |
|
|
|
|
|
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { |
|
|
cert, err := l.getCertificateWithCache() |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
|
|
|
if l.SNIGuard == nil { |
|
|
return cert, nil |
|
|
} |
|
|
err = l.SNIGuard(info, cert) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
|
|
|
return cert, nil |
|
|
} |
|
|
|
|
|
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { |
|
|
fi, err := os.Stat(l.CertFile) |
|
|
if err != nil { |
|
|
err = fmt.Errorf("failed to stat certificate file: %w", err) |
|
|
return |
|
|
} |
|
|
certModTime = fi.ModTime() |
|
|
|
|
|
fi, err = os.Stat(l.KeyFile) |
|
|
if err != nil { |
|
|
err = fmt.Errorf("failed to stat key file: %w", err) |
|
|
return |
|
|
} |
|
|
keyModTime = fi.ModTime() |
|
|
return |
|
|
} |
|
|
|
|
|
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) { |
|
|
c := &localCertificateCache{} |
|
|
|
|
|
c.certModTime, c.keyModTime, err = l.checkModTime() |
|
|
if err != nil { |
|
|
return |
|
|
} |
|
|
|
|
|
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile) |
|
|
if err != nil { |
|
|
return |
|
|
} |
|
|
c.certificate = &cert |
|
|
if c.certificate.Leaf == nil { |
|
|
|
|
|
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) |
|
|
if err != nil { |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
cache = c |
|
|
return |
|
|
} |
|
|
|
|
|
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { |
|
|
cache := l.cache.Load() |
|
|
|
|
|
certModTime, keyModTime, terr := l.checkModTime() |
|
|
if terr != nil { |
|
|
if cache != nil { |
|
|
|
|
|
return cache.certificate, nil |
|
|
} |
|
|
return nil, terr |
|
|
} |
|
|
|
|
|
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) { |
|
|
|
|
|
return cache.certificate, nil |
|
|
} |
|
|
|
|
|
if cache != nil { |
|
|
if !l.lock.TryLock() { |
|
|
|
|
|
return cache.certificate, nil |
|
|
} |
|
|
} else { |
|
|
l.lock.Lock() |
|
|
} |
|
|
defer l.lock.Unlock() |
|
|
|
|
|
if l.cache.Load() != cache { |
|
|
|
|
|
return l.cache.Load().certificate, nil |
|
|
} |
|
|
|
|
|
newCache, err := l.makeCache() |
|
|
if err != nil { |
|
|
if cache != nil { |
|
|
|
|
|
return cache.certificate, nil |
|
|
} |
|
|
return nil, err |
|
|
} |
|
|
|
|
|
l.cache.Store(newCache) |
|
|
return newCache.certificate, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// getNameFromClientHello returns hello.ServerName lower-cased with surrounding
// whitespace trimmed. When that is empty (the client sent no SNI), it falls
// back to the IP of the connection's local address, with any IPv6 zone
// ("%iface") removed. It returns "" when no connection or local address is
// available.
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
	normalizedName := func(serverName string) string {
		return strings.ToLower(strings.TrimSpace(serverName))
	}
	localIPFromConn := func(c net.Conn) string {
		if c == nil {
			return ""
		}
		addr := c.LocalAddr()
		if addr == nil {
			return ""
		}
		localAddr := addr.String()
		ip, _, err := net.SplitHostPort(localAddr)
		if err != nil {
			// Not in host:port form; use the address text as-is.
			ip = localAddr
		}
		// A zone identifier is never part of a certificate IP SAN.
		ip, _, _ = strings.Cut(ip, "%")
		return ip
	}

	if name := normalizedName(hello.ServerName); name != "" {
		return name
	}
	return localIPFromConn(hello.Conn)
}

// SNIGuardDNSSAN is an SNIGuardFunc that applies SNIGuardStrict only when the
// certificate lists at least one DNS SAN; certificates without DNS SANs accept
// any name. It returns an error (rather than panicking) when cert is nil or
// its leaf certificate cannot be determined.
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
	leaf, err := sniGuardLeafCertificate(cert)
	if err != nil {
		return err
	}
	if len(leaf.DNSNames) == 0 {
		return nil
	}
	return sniGuardVerifyName(info, leaf)
}

// SNIGuardStrict is an SNIGuardFunc that requires the name derived from the
// ClientHello (see getNameFromClientHello) to be valid for the certificate
// according to x509.Certificate.VerifyHostname. Verification errors are
// wrapped with an "sni guard:" prefix. It returns an error (rather than
// panicking) when cert is nil or its leaf certificate cannot be determined.
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
	leaf, err := sniGuardLeafCertificate(cert)
	if err != nil {
		return err
	}
	return sniGuardVerifyName(info, leaf)
}

// sniGuardLeafCertificate returns cert.Leaf, parsing the first DER certificate
// of the chain when Leaf was not populated. cert is not modified, since it may
// be shared across concurrent handshakes.
func sniGuardLeafCertificate(cert *tls.Certificate) (*x509.Certificate, error) {
	if cert == nil {
		return nil, fmt.Errorf("sni guard: no certificate")
	}
	if cert.Leaf != nil {
		return cert.Leaf, nil
	}
	if len(cert.Certificate) == 0 {
		return nil, fmt.Errorf("sni guard: certificate chain is empty")
	}
	leaf, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		return nil, fmt.Errorf("sni guard: parse leaf certificate: %w", err)
	}
	return leaf, nil
}

// sniGuardVerifyName checks the ClientHello's normalized name against leaf.
func sniGuardVerifyName(info *tls.ClientHelloInfo, leaf *x509.Certificate) error {
	if err := leaf.VerifyHostname(getNameFromClientHello(info)); err != nil {
		return fmt.Errorf("sni guard: %w", err)
	}
	return nil
}
|
|
|