ColaCoca's picture
jeju 모델 업로드
8eb2cb0
package server
import (
"fmt"
"io"
"net"
"os"
"sync"
"time"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
// minPort and maxPort are the bounds handed to randomPort when a client asks
// for a remote forward on port 0.
const minPort, maxPort = 55000, 65000
// SSHServer defines SSH server instance.
//
// It accepts SSH connections on the configured address, tracks connected
// clients by id, and turns their "tcpip-forward" requests into TCP listeners
// on this host.
type SSHServer struct {
// mu serializes mutations of clients.
mu sync.Mutex
// opts supplies the host key path, SSH listen address and domain used by Run.
opts *Options
// listener is the SSH TCP listener; assigned by listen and closed by Close.
listener net.Listener
// config is the SSH server configuration; NewSSHServer enables NoClientAuth
// and Run adds the host key.
config *ssh.ServerConfig
// running carries the server's terminal error from closeWith to Wait.
running chan error
// isRunning is true until closeWith first runs; closeWith and Wait check it.
isRunning bool
// clients maps client ids to their connection state.
clients map[string]*client
// addr is the SSH listen address, copied from opts.SSHAddr in Run.
addr string
// domain is the public domain (opts.Domain) reported to clients.
domain string
// logger receives the server's log output.
logger *zap.SugaredLogger
}
// client holds the per-connection state of one connected SSH client.
type client struct {
// mu guards listeners.
mu sync.Mutex
// id identifies the client; generated by randID at connect time and
// replaceable through a "set-id" request.
id string
// tcpConn is the raw TCP connection underlying sshConn.
tcpConn net.Conn
// sshConn is the server side of the established SSH connection.
sshConn *ssh.ServerConn
// ch is the most recently accepted SSH channel (nil until one is opened);
// tunnel information is written to it.
ch ssh.Channel
// listeners holds the client's forward listeners, keyed by bindInfo.Bound
// (presumably the "addr:port" bind string — confirm against bindInfo).
listeners map[string]net.Listener
// addr and port record the address and port of the latest tcpip-forward.
addr string
port uint32
}
// write sends data to the client's SSH channel. It does nothing if no channel
// has been accepted yet, and write errors are ignored (best effort).
func (c *client) write(data string) {
	if c.ch == nil {
		return
	}
	_, _ = io.WriteString(c.ch, data)
}
// NewSSHServer returns new instance of SSHServer.
//
// The server accepts clients without authentication (NoClientAuth) and starts
// out marked as running; the host key and addresses are loaded later by Run.
func NewSSHServer(opts *Options, logger *zap.SugaredLogger) *SSHServer {
	srv := &SSHServer{
		opts:      opts,
		logger:    logger,
		clients:   make(map[string]*client),
		running:   make(chan error, 1),
		isRunning: true,
	}
	srv.config = &ssh.ServerConfig{NoClientAuth: true}
	return srv
}
// Run starts the SSH server.
//
// It loads and parses the host key named by opts.PrivateKey, records the SSH
// listen address and public domain from opts, and then serves connections via
// listen. Key loading and parsing errors are returned directly.
//
// NOTE(review): listen runs in the caller's goroutine (Go evaluates the
// arguments of a go statement before starting the goroutine), so Run blocks
// while the server accepts connections. Only the hand-off of listen's result
// to Wait via closeWith happens on a new goroutine, after which Run returns
// nil. Confirm callers expect Run to block.
func (s *SSHServer) Run() error {
	keyBytes, err := os.ReadFile(s.opts.PrivateKey)
	if err != nil {
		return err
	}
	hostKey, err := ssh.ParsePrivateKey(keyBytes)
	if err != nil {
		return err
	}
	s.config.AddHostKey(hostKey)

	s.addr = s.opts.SSHAddr
	s.domain = s.opts.Domain

	serveErr := s.listen()
	go s.closeWith(serveErr)
	return nil
}
// Close closes and stops the SSH server.
//
// If the server has not stopped yet, Close publishes a nil terminal error for
// Wait, then closes the SSH listener. Existing SSH connections are not closed
// here. Calling Close before the listener has been bound is safe: there is
// nothing to close and it returns nil.
func (s *SSHServer) Close() error {
	s.closeWith(nil)

	s.mu.Lock()
	ln := s.listener
	s.mu.Unlock()
	if ln == nil {
		return nil
	}
	return ln.Close()
}

// Wait waits for server to be stopped and returns the error that stopped it
// (nil after Close). Only the first Wait receives that error; any later call
// returns an "already closed" error immediately.
func (s *SSHServer) Wait() error {
	err, ok := <-s.running
	if !ok {
		return fmt.Errorf("already closed")
	}
	return err
}

// closeWith marks the server as stopped and publishes err to Wait. Only the
// first call has an effect; later calls are no-ops.
func (s *SSHServer) closeWith(err error) {
	s.mu.Lock()
	if !s.isRunning {
		s.mu.Unlock()
		return
	}
	s.isRunning = false
	s.mu.Unlock()

	// running has capacity 1 and receives exactly one value, so this send
	// never blocks. The value stays buffered until some Wait picks it up, even
	// if Wait is called after this point. Closing the channel then makes every
	// later Wait return "already closed" instead of blocking forever.
	s.running <- err
	close(s.running)
}
func (s *SSHServer) listen() error {
listener, err := net.Listen("tcp", s.addr)
if err != nil {
return err
}
s.listener = listener
s.logger.Infof("starting SSH server on %s", s.addr)
for {
tcpConn, err := s.listener.Accept()
if err != nil {
s.logger.Errorf("failed to accept incoming connection: %v", err)
continue
}
sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, s.config)
if err != nil {
s.logger.Errorf("failed to handshake: %v", err)
continue
}
genid:
id := randID()
if _, ok := s.clients[id]; ok {
goto genid
}
c := &client{
id: id,
tcpConn: tcpConn,
sshConn: sshConn,
listeners: make(map[string]net.Listener),
addr: "",
port: 0,
}
s.logger.Infof("new SSH connection from %s (%s)", sshConn.RemoteAddr().String(), sshConn.ClientVersion())
go func(c *client) {
err := c.sshConn.Wait()
s.logger.Infof("[%s] SSH connection closed: %v", c.id, err)
c.mu.Lock()
for bind, listener := range c.listeners {
s.logger.Debugf("[%s] closing listener bound to %s", c.id, bind)
listener.Close()
}
c.mu.Unlock()
s.mu.Lock()
delete(s.clients, c.id)
s.mu.Unlock()
}(c)
go s.handleRequests(c, reqs)
go s.handleChannels(c, chans)
}
}
// handleChannels accepts every channel the client opens and stores the most
// recent one in client.ch, which is used to print tunnel information back to
// the user.
//
// Both the NewChannel stream and each accepted channel's request stream are
// kept drained. x/crypto/ssh requires them to be serviced; otherwise its
// internal queues fill up and the whole SSH connection, including forwarded
// traffic, stalls.
func (s *SSHServer) handleChannels(client *client, chans <-chan ssh.NewChannel) {
	for nch := range chans {
		chconn, chreqs, err := nch.Accept()
		if err != nil {
			s.logger.Errorf("[%s] could not accept channel: %v", client.id, err)
			// Keep servicing chans rather than abandoning it.
			continue
		}

		// Channel requests (e.g. "pty-req", "shell", "window-change") are read
		// but deliberately left unanswered, as before. Replying false instead
		// may make some clients abort the session.
		go func(reqs <-chan *ssh.Request) {
			for range reqs {
			}
		}(chreqs)

		client.mu.Lock()
		client.ch = chconn
		client.mu.Unlock()
	}
}
// handleRequests services the global requests of one SSH connection:
//   - "set-id": the client proposes its own id. It is adopted only if no
//     registered client uses it, and the request is always answered true.
//   - "tcpip-forward": opens a public listener via handleForward and
//     registers the client. If a channel is open, the tunnel details are
//     printed to it. A failed forward disconnects the client.
//   - anything else is answered false.
//
// Every request pushes the TCP connection's deadline two minutes ahead. Once a
// client has sent a request, it must send another within two minutes or its
// connection times out.
func (s *SSHServer) handleRequests(client *client, reqs <-chan *ssh.Request) {
	for req := range reqs {
		client.tcpConn.SetDeadline(time.Now().Add(2 * time.Minute))

		switch req.Type {
		case "set-id":
			var payload idRequestPayload
			if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
				s.logger.Errorf("[%s] Unable to unmarshal payload: %v", client.id, err)
			}
			if payload.ID != "" {
				// Check and claim in one critical section so two clients
				// cannot both take the same id.
				s.mu.Lock()
				if _, taken := s.clients[payload.ID]; !taken {
					if s.clients[client.id] == client {
						delete(s.clients, client.id)
					}
					client.id = payload.ID
					s.clients[client.id] = client
				}
				s.mu.Unlock()
			}
			req.Reply(true, []byte{})

		case "tcpip-forward":
			listener, info, err := s.handleForward(client, req)
			if err != nil {
				s.logger.Errorf("[%s] error, disconnecting: %v", client.id, err)
				client.tcpConn.Close()
				continue
			}

			client.mu.Lock()
			client.addr = info.Addr
			client.port = info.Port
			client.listeners[info.Bound] = listener
			ch := client.ch
			client.mu.Unlock()

			s.mu.Lock()
			s.clients[client.id] = client
			id := client.id
			s.mu.Unlock()

			go s.handleListener(client, info, listener)
			if ch != nil {
				data := clientResponse{
					id:     id,
					domain: s.domain,
					port:   info.Port,
				}
				renderMessage(data, ch)
				renderTable(data, ch)
			}

		default:
			req.Reply(false, []byte{})
		}
	}
}
// handleListener accepts public connections on one forward listener and hands
// each to handleForwardTCPIP on its own goroutine.
//
// Timeout and temporary accept errors are logged and retried. Temporary
// errors are retried after a short pause. Any other error, such as the
// listener being closed when the client's SSH connection ends, stops the loop.
func (s *SSHServer) handleListener(client *client, bindInfo *bindInfo, listener net.Listener) {
	for {
		conn, err := listener.Accept()
		if err == nil {
			go s.handleForwardTCPIP(client, bindInfo, conn)
			continue
		}

		// Checked assertion: an unexpected error type ends the loop instead of
		// panicking.
		neterr, ok := err.(net.Error)
		switch {
		case !ok:
			return
		case neterr.Timeout():
			s.logger.Errorf("[%s] accept failed with timeout: %v", client.id, err)
		case neterr.Temporary():
			s.logger.Errorf("[%s] accept failed with temporary: %v", client.id, err)
			// Back off so a persistent condition such as EMFILE does not spin.
			time.Sleep(50 * time.Millisecond)
		default:
			return
		}
	}
}
// handleForwardTCPIP opens a "forwarded-tcpip" channel back to the SSH client
// for one accepted public connection, then starts proxying between the two on
// a new goroutine. If the channel cannot be opened, the public connection is
// closed.
func (s *SSHServer) handleForwardTCPIP(client *client, bindInfo *bindInfo, conn net.Conn) {
	origin := conn.RemoteAddr().(*net.TCPAddr)

	// Payload: the forwarded address/port, then the originator's address/port.
	msg := ssh.Marshal(&forwardedTCPPayload{
		bindInfo.Addr,
		bindInfo.Port,
		origin.IP.String(),
		uint32(origin.Port),
	})

	channel, chReqs, err := client.sshConn.OpenChannel("forwarded-tcpip", msg)
	if err != nil {
		s.logger.Errorf("[%s] unable to get channel: %v. Hanging up requesting party!", client.id, err)
		conn.Close()
		return
	}
	s.logger.Debugf("[%s] channel opened for client %s:%d <-> %s", client.id, bindInfo.Addr, bindInfo.Port, origin.String())

	go ssh.DiscardRequests(chReqs)
	go s.handleForwardTCPIPTransfer(channel, conn)
}
// handleForward serves one "tcpip-forward" request from client.
//
// It opens a TCP listener on the requested address and port. When port 0 is
// requested, it uses a port chosen by randomPort(minPort, maxPort) and retries
// a bounded number of times. It replies to the request with the port actually
// bound.
//
// On failure, the request is answered false and an error is returned. On
// success, the positive reply has already been sent, and the listener and its
// bind information are returned.
//
// NOTE(review): payload.Addr and payload.Port come from the client unchecked,
// so a client may bind any interface or port this process may use. Confirm
// this is intended.
func (s *SSHServer) handleForward(client *client, req *ssh.Request) (net.Listener, *bindInfo, error) {
	// Cap retries for port-0 requests. An address that can never be bound
	// (e.g. one not assigned to this host) must fail the request instead of
	// looping forever.
	const maxRandomPortAttempts = 100

	var payload tcpIPForwardPayload
	if err := ssh.Unmarshal(req.Payload, &payload); err != nil {
		s.logger.Errorf("[%s] unable to unmarshal payload: %v", client.id, err)
		req.Reply(false, []byte{})
		return nil, nil, fmt.Errorf("unable to parse payload: %w", err)
	}
	s.logger.Debugf("[%s] request: %s %v %v", client.id, req.Type, req.WantReply, payload)

	var ln net.Listener
	for attempt := 1; ; attempt++ {
		port := fmt.Sprint(payload.Port)
		if payload.Port == 0 {
			port = fmt.Sprint(randomPort(minPort, maxPort))
		}
		// JoinHostPort brackets IPv6 literals such as "::1", which plain
		// "%s:%d" formatting would not.
		bind := net.JoinHostPort(payload.Addr, port)

		var err error
		ln, err = net.Listen("tcp", bind)
		if err == nil {
			break
		}
		if payload.Port == 0 && attempt < maxRandomPortAttempts {
			s.logger.Errorf("[%s] listen failed for: %s %v, retrying on another port", client.id, bind, err)
			continue
		}
		s.logger.Errorf("[%s] listen failed for: %s %v", client.id, bind, err)
		req.Reply(false, []byte{})
		return nil, nil, fmt.Errorf("unable to listen on %s: %w", bind, err)
	}

	port := ln.Addr().(*net.TCPAddr).Port
	bind := net.JoinHostPort(payload.Addr, fmt.Sprint(port))
	s.logger.Debugf("[%s] listening on %s", client.id, bind)
	reply := tcpIPForwardPayloadReply{uint32(port)}
	req.Reply(true, ssh.Marshal(&reply))
	return ln, &bindInfo{bind, uint32(port), payload.Addr}, nil
}
// handleForwardTCPIPTransfer copies bytes in both directions between the SSH
// channel c and the public connection conn. As soon as either direction
// finishes, both ends are closed, which should make the other copy terminate
// once its pending I/O fails.
func (s *SSHServer) handleForwardTCPIPTransfer(c ssh.Channel, conn net.Conn) {
	defer conn.Close()
	defer c.Close()

	// Buffered for both copiers: only the first completion is awaited, and the
	// second goroutine must still be able to send and exit rather than block
	// forever (a goroutine leak per proxied connection).
	done := make(chan struct{}, 2)
	go func() {
		io.Copy(c, conn)
		done <- struct{}{}
	}()
	go func() {
		io.Copy(conn, c)
		done <- struct{}{}
	}()
	<-done
}