| package tproxy |
|
|
| import ( |
| "errors" |
| "net" |
| "time" |
|
|
| "github.com/apernet/go-tproxy" |
| "github.com/apernet/hysteria/core/v2/client" |
| ) |
|
|
const (
	// udpBufferSize is the size in bytes (4 KiB) of the buffers used to read
	// UDP packets from local clients, both in ListenAndServe and in the
	// local->remote direction of forwarding.
	udpBufferSize = 4 << 10

	// defaultTimeout is the idle timeout of a UDP session used when
	// UDPTProxy.Timeout is zero.
	defaultTimeout = time.Minute
)
|
|
| type UDPTProxy struct { |
| HyClient client.Client |
| Timeout time.Duration |
| EventLogger UDPEventLogger |
| } |
|
|
| type UDPEventLogger interface { |
| Connect(addr, reqAddr net.Addr) |
| Error(addr, reqAddr net.Addr, err error) |
| } |
|
|
| func (r *UDPTProxy) ListenAndServe(laddr *net.UDPAddr) error { |
| conn, err := tproxy.ListenUDP("udp", laddr) |
| if err != nil { |
| return err |
| } |
| defer conn.Close() |
| buf := make([]byte, udpBufferSize) |
| for { |
| |
| |
| |
| n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf) |
| if err != nil { |
| return err |
| } |
| r.newPair(srcAddr, dstAddr, buf[:n]) |
| } |
| } |
|
|
| func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) { |
| if r.EventLogger != nil { |
| r.EventLogger.Connect(srcAddr, dstAddr) |
| } |
| var closeErr error |
| defer func() { |
| |
| |
| if r.EventLogger != nil && closeErr != nil { |
| r.EventLogger.Error(srcAddr, dstAddr, closeErr) |
| } |
| }() |
| conn, err := tproxy.DialUDP("udp", dstAddr, srcAddr) |
| if err != nil { |
| closeErr = err |
| return |
| } |
| hyConn, err := r.HyClient.UDP() |
| if err != nil { |
| _ = conn.Close() |
| closeErr = err |
| return |
| } |
| |
| err = hyConn.Send(initPkt, dstAddr.String()) |
| if err != nil { |
| _ = conn.Close() |
| _ = hyConn.Close() |
| closeErr = err |
| return |
| } |
| |
| go func() { |
| err := r.forwarding(conn, hyConn, dstAddr.String()) |
| _ = conn.Close() |
| _ = hyConn.Close() |
| if r.EventLogger != nil { |
| var netErr net.Error |
| if errors.As(err, &netErr) && netErr.Timeout() { |
| |
| err = nil |
| } |
| r.EventLogger.Error(srcAddr, dstAddr, err) |
| } |
| }() |
| } |
|
|
| func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) error { |
| errChan := make(chan error, 2) |
| |
| go func() { |
| for { |
| bs, _, err := hyConn.Receive() |
| if err != nil { |
| errChan <- err |
| return |
| } |
| _, err = conn.Write(bs) |
| if err != nil { |
| errChan <- err |
| return |
| } |
| _ = r.updateConnDeadline(conn) |
| } |
| }() |
| |
| go func() { |
| buf := make([]byte, udpBufferSize) |
| for { |
| _ = r.updateConnDeadline(conn) |
| n, err := conn.Read(buf) |
| if n > 0 { |
| err := hyConn.Send(buf[:n], dst) |
| if err != nil { |
| errChan <- err |
| return |
| } |
| } |
| if err != nil { |
| errChan <- err |
| return |
| } |
| } |
| }() |
| return <-errChan |
| } |
|
|
| func (r *UDPTProxy) updateConnDeadline(conn *net.UDPConn) error { |
| if r.Timeout == 0 { |
| return conn.SetReadDeadline(time.Now().Add(defaultTimeout)) |
| } else { |
| return conn.SetReadDeadline(time.Now().Add(r.Timeout)) |
| } |
| } |
|
|