|
|
package server |
|
|
|
|
|
import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"time"

	"go.uber.org/zap"
	"golang.org/x/crypto/ssh"
)
|
|
|
|
|
// minPort and maxPort are the bounds passed to randomPort when a client asks
// handleForward for a remote forward on port 0 (i.e. "pick a port for me").
const minPort, maxPort = 55000, 65000
|
|
|
|
|
|
|
|
// SSHServer accepts SSH clients, lets them request remote TCP forwards
// ("tcpip-forward"), and relays connections accepted on the forwarded ports
// back to the owning client over "forwarded-tcpip" channels.
type SSHServer struct {
	// mu is held around writes to clients.
	mu sync.Mutex

	// opts supplies the host key path, SSH listen address and domain read by Run.
	opts *Options

	// listener accepts incoming TCP connections for SSH; assigned by listen.
	listener net.Listener

	// config is the SSH server configuration (NoClientAuth); Run adds the host key.
	config *ssh.ServerConfig

	// running carries the error the server stopped with to Wait. It is
	// created with capacity 1 and sent to at most once, by closeWith.
	running chan error

	// isRunning is true from construction until closeWith first runs.
	isRunning bool

	// clients holds registered clients keyed by their ID.
	clients map[string]*client

	// addr is the SSH listen address (copied from opts.SSHAddr by Run).
	addr string

	// domain is copied from opts.Domain by Run and included in the response
	// rendered to clients after a successful forward.
	domain string

	logger *zap.SugaredLogger
}
|
|
|
|
|
// client is the per-connection state for one connected SSH client.
type client struct {
	// mu guards listeners.
	mu sync.Mutex

	// id keys the client in SSHServer.clients. It starts as a randID value
	// and can be replaced by a "set-id" request.
	id string

	// tcpConn is the raw connection underneath sshConn. handleRequests
	// refreshes its deadline on every request and closes it when a forward
	// request fails.
	tcpConn net.Conn

	sshConn *ssh.ServerConn

	// ch is the most recently accepted channel from the client; status
	// output is written to it. It is nil until the client opens a channel.
	ch ssh.Channel

	// listeners holds this client's forwarded-port listeners keyed by bind
	// address; they are closed when the SSH connection ends.
	listeners map[string]net.Listener

	// addr is the bind address the client requested and port the port
	// actually bound, both from the most recent successful tcpip-forward.
	addr string

	port uint32
}
|
|
|
|
|
func (c *client) write(data string) { |
|
|
if c.ch != nil { |
|
|
io.WriteString(c.ch, data) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func NewSSHServer(opts *Options, logger *zap.SugaredLogger) *SSHServer { |
|
|
return &SSHServer{ |
|
|
opts: opts, |
|
|
config: &ssh.ServerConfig{ |
|
|
NoClientAuth: true, |
|
|
}, |
|
|
running: make(chan error, 1), |
|
|
clients: make(map[string]*client), |
|
|
logger: logger, |
|
|
isRunning: true, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func (s *SSHServer) Run() error { |
|
|
privateKeyContent, err := os.ReadFile(s.opts.PrivateKey) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
private, err := ssh.ParsePrivateKey(privateKeyContent) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
s.config.AddHostKey(private) |
|
|
s.addr = s.opts.SSHAddr |
|
|
s.domain = s.opts.Domain |
|
|
|
|
|
go s.closeWith(s.listen()) |
|
|
return nil |
|
|
} |
|
|
|
|
|
|
|
|
func (s *SSHServer) Close() error { |
|
|
s.closeWith(nil) |
|
|
return s.listener.Close() |
|
|
} |
|
|
|
|
|
|
|
|
func (s *SSHServer) Wait() error { |
|
|
if !s.isRunning { |
|
|
return fmt.Errorf("already closed") |
|
|
} |
|
|
return <-s.running |
|
|
} |
|
|
|
|
|
func (s *SSHServer) closeWith(err error) { |
|
|
if !s.isRunning { |
|
|
return |
|
|
} |
|
|
s.isRunning = false |
|
|
s.running <- err |
|
|
} |
|
|
|
|
|
func (s *SSHServer) listen() error { |
|
|
listener, err := net.Listen("tcp", s.addr) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
s.listener = listener |
|
|
|
|
|
s.logger.Infof("starting SSH server on %s", s.addr) |
|
|
|
|
|
for { |
|
|
tcpConn, err := s.listener.Accept() |
|
|
if err != nil { |
|
|
s.logger.Errorf("failed to accept incoming connection: %v", err) |
|
|
continue |
|
|
} |
|
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.config) |
|
|
if err != nil { |
|
|
s.logger.Errorf("failed to handshake: %v", err) |
|
|
continue |
|
|
} |
|
|
|
|
|
genid: |
|
|
id := randID() |
|
|
if _, ok := s.clients[id]; ok { |
|
|
goto genid |
|
|
} |
|
|
|
|
|
c := &client{ |
|
|
id: id, |
|
|
tcpConn: tcpConn, |
|
|
sshConn: sshConn, |
|
|
listeners: make(map[string]net.Listener), |
|
|
addr: "", |
|
|
port: 0, |
|
|
} |
|
|
s.logger.Infof("new SSH connection from %s (%s)", sshConn.RemoteAddr().String(), sshConn.ClientVersion()) |
|
|
|
|
|
go func(c *client) { |
|
|
err := c.sshConn.Wait() |
|
|
s.logger.Infof("[%s] SSH connection closed: %v", c.id, err) |
|
|
|
|
|
c.mu.Lock() |
|
|
for bind, listener := range c.listeners { |
|
|
s.logger.Debugf("[%s] closing listener bound to %s", c.id, bind) |
|
|
listener.Close() |
|
|
} |
|
|
c.mu.Unlock() |
|
|
|
|
|
s.mu.Lock() |
|
|
delete(s.clients, c.id) |
|
|
s.mu.Unlock() |
|
|
}(c) |
|
|
|
|
|
go s.handleRequests(c, reqs) |
|
|
go s.handleChannels(c, chans) |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *SSHServer) handleChannels(client *client, chans <-chan ssh.NewChannel) { |
|
|
for nch := range chans { |
|
|
chconn, _, err := nch.Accept() |
|
|
if err != nil { |
|
|
s.logger.Errorf("[%s] could not accept channel: %v", client.id, err) |
|
|
return |
|
|
} |
|
|
client.ch = chconn |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *SSHServer) handleRequests(client *client, reqs <-chan *ssh.Request) { |
|
|
for req := range reqs { |
|
|
client.tcpConn.SetDeadline(time.Now().Add(2 * time.Minute)) |
|
|
|
|
|
if req.Type == "set-id" { |
|
|
var payload idRequestPayload |
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
|
|
s.logger.Errorf("[%s] Unable to unmarshal payload: %v", client.id, err) |
|
|
} |
|
|
if payload.ID != "" { |
|
|
if _, ok := s.clients[payload.ID]; !ok { |
|
|
s.mu.Lock() |
|
|
delete(s.clients, client.id) |
|
|
client.id = payload.ID |
|
|
s.clients[client.id] = client |
|
|
s.mu.Unlock() |
|
|
} |
|
|
} |
|
|
req.Reply(true, []byte{}) |
|
|
continue |
|
|
} |
|
|
|
|
|
if req.Type == "tcpip-forward" { |
|
|
listener, bindInfo, err := s.handleForward(client, req) |
|
|
if err != nil { |
|
|
s.logger.Errorf("[%s] error, disconnecting: %v", client.id, err) |
|
|
client.tcpConn.Close() |
|
|
continue |
|
|
} |
|
|
|
|
|
client.addr = bindInfo.Addr |
|
|
client.port = bindInfo.Port |
|
|
|
|
|
client.mu.Lock() |
|
|
client.listeners[bindInfo.Bound] = listener |
|
|
client.mu.Unlock() |
|
|
|
|
|
s.mu.Lock() |
|
|
s.clients[client.id] = client |
|
|
s.mu.Unlock() |
|
|
|
|
|
go s.handleListener(client, bindInfo, listener) |
|
|
|
|
|
if client.ch != nil { |
|
|
data := clientResponse{ |
|
|
id: client.id, |
|
|
domain: s.domain, |
|
|
port: client.port, |
|
|
} |
|
|
|
|
|
renderMessage(data, client.ch) |
|
|
renderTable(data, client.ch) |
|
|
} |
|
|
} else { |
|
|
req.Reply(false, []byte{}) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *SSHServer) handleListener(client *client, bindInfo *bindInfo, listener net.Listener) { |
|
|
for { |
|
|
conn, err := listener.Accept() |
|
|
if err != nil { |
|
|
neterr := err.(net.Error) |
|
|
if neterr.Timeout() { |
|
|
s.logger.Errorf("[%s] accept failed with timeout: %v", client.id, err) |
|
|
continue |
|
|
} |
|
|
if neterr.Temporary() { |
|
|
s.logger.Errorf("[%s] accept failed with temporary: %v", client.id, err) |
|
|
continue |
|
|
} |
|
|
|
|
|
break |
|
|
} |
|
|
|
|
|
go s.handleForwardTCPIP(client, bindInfo, conn) |
|
|
} |
|
|
} |
|
|
|
|
|
func (s *SSHServer) handleForwardTCPIP(client *client, bindInfo *bindInfo, conn net.Conn) { |
|
|
remoteAddr := conn.RemoteAddr().(*net.TCPAddr) |
|
|
raddr := remoteAddr.IP.String() |
|
|
rport := uint32(remoteAddr.Port) |
|
|
|
|
|
payload := forwardedTCPPayload{bindInfo.Addr, bindInfo.Port, raddr, rport} |
|
|
mpayload := ssh.Marshal(&payload) |
|
|
|
|
|
|
|
|
c, requests, err := client.sshConn.OpenChannel("forwarded-tcpip", mpayload) |
|
|
if err != nil { |
|
|
s.logger.Errorf("[%s] unable to get channel: %v. Hanging up requesting party!", client.id, err) |
|
|
conn.Close() |
|
|
return |
|
|
} |
|
|
s.logger.Debugf("[%s] channel opened for client %s:%d <-> %s", client.id, bindInfo.Addr, bindInfo.Port, remoteAddr.String()) |
|
|
go ssh.DiscardRequests(requests) |
|
|
go s.handleForwardTCPIPTransfer(c, conn) |
|
|
} |
|
|
|
|
|
func (s *SSHServer) handleForward(client *client, req *ssh.Request) (net.Listener, *bindInfo, error) { |
|
|
var payload tcpIPForwardPayload |
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err != nil { |
|
|
s.logger.Errorf("[%s] unable to unmarshal payload: %v", client.id, err) |
|
|
req.Reply(false, []byte{}) |
|
|
return nil, nil, fmt.Errorf("unable to parse payload") |
|
|
} |
|
|
|
|
|
s.logger.Debugf("[%s] request: %s %v %v", client.id, req.Type, req.WantReply, payload) |
|
|
|
|
|
listen: |
|
|
bind := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) |
|
|
if payload.Port == 0 { |
|
|
bind = fmt.Sprintf("%s:%d", payload.Addr, randomPort(minPort, maxPort)) |
|
|
} |
|
|
|
|
|
ln, err := net.Listen("tcp", bind) |
|
|
if err != nil { |
|
|
if payload.Port == 0 { |
|
|
s.logger.Errorf("[%s] listen failed for: %s %v, retrying on another port", client.id, bind, err) |
|
|
goto listen |
|
|
} |
|
|
s.logger.Errorf("[%s] listen failed for: %s %v", client.id, bind, err) |
|
|
req.Reply(false, []byte{}) |
|
|
return nil, nil, fmt.Errorf("unable to listen") |
|
|
} |
|
|
port := ln.Addr().(*net.TCPAddr).Port |
|
|
bind = fmt.Sprintf("%s:%d", payload.Addr, port) |
|
|
|
|
|
s.logger.Debugf("[%s] listening on %s", client.id, bind) |
|
|
reply := tcpIPForwardPayloadReply{uint32(port)} |
|
|
req.Reply(true, ssh.Marshal(&reply)) |
|
|
|
|
|
return ln, &bindInfo{bind, uint32(port), payload.Addr}, nil |
|
|
} |
|
|
|
|
|
func (s *SSHServer) handleForwardTCPIPTransfer(c ssh.Channel, conn net.Conn) { |
|
|
defer conn.Close() |
|
|
defer c.Close() |
|
|
done := make(chan struct{}) |
|
|
|
|
|
go func() { |
|
|
io.Copy(c, conn) |
|
|
done <- struct{}{} |
|
|
}() |
|
|
|
|
|
go func() { |
|
|
io.Copy(conn, c) |
|
|
done <- struct{}{} |
|
|
}() |
|
|
|
|
|
<-done |
|
|
} |
|
|
|