// streamion/server.go — from commit 5ec2e9b ("Upload 57 files") by cursorpro.
package main
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/caarlos0/env"
	"golang.zx2c4.com/wireguard/conn"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/tun/netstack"
)
// params is the proxy's runtime configuration, filled from environment
// variables by env.Parse in main.
type params struct {
	// Credentials proxy clients must present. Authentication is enforced only
	// when both User and Password are non-empty.
	User     string `env:"PROXY_USER" envDefault:""`
	Password string `env:"PROXY_PASS" envDefault:""`
	// Port the HTTP proxy listens on (bound on all interfaces).
	Port string `env:"PORT" envDefault:"8080"`
	// WireGuard Params
	// WireGuard is used only when both WgPrivateKey and WgPeerEndpoint are set;
	// otherwise the proxy dials destinations directly. Keys are Base64-encoded.
	WgPrivateKey string `env:"WIREGUARD_INTERFACE_PRIVATE_KEY"`
	// Local tunnel address; anything after a "/" is ignored.
	WgAddress       string `env:"WIREGUARD_INTERFACE_ADDRESS"` // e.g., 10.0.0.2/32
	WgPeerPublicKey string `env:"WIREGUARD_PEER_PUBLIC_KEY"`
	WgPeerEndpoint  string `env:"WIREGUARD_PEER_ENDPOINT"` // e.g., 1.2.3.4:51820
	// DNS server IP handed to the WireGuard netstack; an unparseable value
	// falls back to 1.1.1.1.
	WgDNS string `env:"WIREGUARD_INTERFACE_DNS" envDefault:"1.1.1.1"`
}
// tnet is the userspace WireGuard network stack used for outbound dials.
// It stays nil in direct mode. startWireGuard assigns it during startup,
// before main starts the HTTP server; the request handlers only read it.
var tnet *netstack.Net
// handleTunneling serves an HTTP CONNECT request. It dials the requested
// host:port, through the WireGuard netstack when tnet is set and directly
// otherwise. It then answers "200 Connection Established" and splices the
// client and upstream connections together in both directions.
//
// The upstream dial happens before the client connection is hijacked, so a
// dial failure is still reported as an ordinary 503 through w.
func handleTunneling(w http.ResponseWriter, r *http.Request) {
	// The same dial bound applies to both the direct and the WireGuard path.
	const dialTimeout = 10 * time.Second

	dest := r.URL.Host
	if dest == "" {
		dest = r.Host
	}

	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
		return
	}

	// Derive from the request context so a client that gives up also aborts
	// the dial.
	ctx, cancel := context.WithTimeout(r.Context(), dialTimeout)
	defer cancel()

	var destConn net.Conn
	var err error
	if tnet != nil {
		destConn, err = tnet.DialContext(ctx, "tcp", dest)
	} else {
		var dialer net.Dialer
		destConn, err = dialer.DialContext(ctx, "tcp", dest)
	}
	if err != nil {
		log.Printf("[ERROR] TUNNEL Dial failed to %s: %v", dest, err)
		http.Error(w, "Service Unavailable", http.StatusServiceUnavailable)
		return
	}

	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		log.Printf("[ERROR] TUNNEL Hijack failed: %v", err)
		destConn.Close()
		// Returning without writing would make the server send an implicit
		// 200, which a CONNECT client would read as "tunnel established".
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}

	// Signal the client that the tunnel is ready.
	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		log.Printf("[ERROR] TUNNEL Write 200 failed: %v", err)
		destConn.Close()
		clientConn.Close()
		return
	}

	// The server may already have read bytes the client sent right after the
	// CONNECT request, such as a pipelined TLS ClientHello. Those bytes sit in
	// clientBuf, not in clientConn, so forward them explicitly or they are lost.
	if n := clientBuf.Reader.Buffered(); n > 0 {
		// Peek cannot fail when n does not exceed Buffered().
		pending, _ := clientBuf.Reader.Peek(n)
		if _, err := destConn.Write(pending); err != nil {
			log.Printf("[ERROR] TUNNEL forwarding buffered client data failed: %v", err)
			destConn.Close()
			clientConn.Close()
			return
		}
	}

	go transfer(destConn, clientConn)
	go transfer(clientConn, destConn)
}
// transfer streams everything readable from source into destination. When
// the copy stops, whether at EOF or on an error from either side, it closes
// source and then destination. Copy errors are intentionally ignored.
func transfer(destination io.WriteCloser, source io.ReadCloser) {
	defer func() {
		source.Close()
		destination.Close()
	}()
	_, _ = io.Copy(destination, source)
}
// hopByHopHeaders are connection-specific headers (RFC 7230 §6.1, plus the
// de-facto Proxy-Connection) that a proxy must not forward. Proxy-Authorization
// in particular carries this proxy's own credentials and must never reach the
// origin server.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

var (
	proxyTransportOnce sync.Once
	proxyTransport     *http.Transport
)

// sharedProxyTransport returns the single transport used for all plain-HTTP
// proxy requests, so upstream connections are pooled and reused. It is built
// lazily because tnet is only set once startWireGuard has run.
func sharedProxyTransport() *http.Transport {
	proxyTransportOnce.Do(func() {
		t := http.DefaultTransport.(*http.Transport).Clone()
		if tnet != nil {
			// Route upstream connections through WireGuard.
			t.DialContext = tnet.DialContext
		}
		proxyTransport = t
	})
	return proxyTransport
}

// removeHopByHopHeaders deletes from h the standard hop-by-hop headers and any
// header named in a Connection header.
func removeHopByHopHeaders(h http.Header) {
	for _, v := range h.Values("Connection") {
		for _, name := range strings.Split(v, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range hopByHopHeaders {
		h.Del(name)
	}
}

// handleHTTP forwards a plain (non-CONNECT) proxy request to its absolute-URL
// target and streams the upstream response back to the client. If the
// upstream round trip fails, it replies 503 with the error text.
func handleHTTP(w http.ResponseWriter, r *http.Request) {
	// Work on a copy so the inbound request is left untouched. RequestURI must
	// be empty on an outbound client request.
	outReq := r.Clone(r.Context())
	outReq.RequestURI = ""
	removeHopByHopHeaders(outReq.Header)

	resp, err := sharedProxyTransport().RoundTrip(outReq)
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	defer resp.Body.Close()

	removeHopByHopHeaders(resp.Header)
	copyHeader(w.Header(), resp.Header)
	w.WriteHeader(resp.StatusCode)
	// Best effort: the status line is already sent, so a mid-stream copy
	// error cannot be reported to the client.
	io.Copy(w, resp.Body)
}
// handleDebug checks the VPN by fetching http://ifconfig.me through the
// WireGuard netstack and reporting the response body as plain text. In direct
// mode it reports that WireGuard is not initialized, still with status 200.
// Probe failures, including a non-200 reply, are answered with 503.
func handleDebug(w http.ResponseWriter, r *http.Request) {
	// Upper bound on how much of the probe response is buffered.
	const maxProbeBody = 4 << 10

	if tnet == nil {
		w.Header().Set("Content-Type", "text/plain")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("Error: WireGuard not initialized (Direct Mode)"))
		return
	}

	// A throwaway client per probe. With keep-alives disabled, its connection
	// is closed after the response instead of lingering in an idle pool that
	// nothing will ever reuse.
	client := &http.Client{
		Transport: &http.Transport{
			DialContext:       tnet.DialContext,
			DisableKeepAlives: true,
		},
		Timeout: 10 * time.Second,
	}
	// Tie the probe to the caller's request so an abandoned request cancels it.
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, "http://ifconfig.me", nil)
	if err != nil {
		http.Error(w, fmt.Sprintf("VPN Connection Failed: %v", err), http.StatusInternalServerError)
		return
	}
	resp, err := client.Do(req)
	if err != nil {
		log.Printf("[DEBUG] VPN Test Failed: %v", err)
		http.Error(w, fmt.Sprintf("VPN Connection Failed: %v", err), http.StatusServiceUnavailable)
		return
	}
	defer resp.Body.Close()

	// Only a 200 body is the bare IP. An error page or redirect body would
	// otherwise be misreported as the public IP.
	if resp.StatusCode != http.StatusOK {
		log.Printf("[DEBUG] VPN Test got unexpected status: %s", resp.Status)
		http.Error(w, fmt.Sprintf("VPN probe failed: ifconfig.me returned %s", resp.Status), http.StatusServiceUnavailable)
		return
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxProbeBody))
	if err != nil {
		http.Error(w, fmt.Sprintf("Failed to read response: %v", err), http.StatusInternalServerError)
		return
	}
	log.Printf("[DEBUG] VPN Test Success. IP: %s", string(body))
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, "VPN Connected! Public IP: %s", body)
}
// copyHeader appends every value of every header in src to dst. Values
// already in dst are kept, and Header.Add canonicalizes each key.
func copyHeader(dst, src http.Header) {
	for name := range src {
		values := src[name]
		for i := 0; i < len(values); i++ {
			dst.Add(name, values[i])
		}
	}
}
// startWireGuard brings up a userspace WireGuard tunnel described by cfg.
// Once the device is configured and up, it publishes the tunnel's network
// stack in the package-level tnet so the proxy handlers dial through the VPN.
//
// If the private key or the peer endpoint is unset, it logs that the proxy
// runs in DIRECT mode and returns nil. Every other failure is returned as an
// error. tnet is assigned only on full success, so it never refers to a
// half-built tunnel.
//
// WgAddress and WgDNS may each hold a single address or a comma-separated
// list (e.g. "10.0.0.2/32, fd00::2/128"); any "/prefix" suffix is ignored.
func startWireGuard(cfg params) error {
	if cfg.WgPrivateKey == "" || cfg.WgPeerEndpoint == "" {
		log.Println("[INFO] WireGuard config missing, running in DIRECT mode (no VPN)")
		return nil
	}
	log.Println("[INFO] Initializing Userspace WireGuard...")

	// Validate the keys before allocating any devices. wireguard-go's UAPI
	// takes hex keys, while configs normally carry Base64.
	privateKeyHex, err := decodeWireGuardKey(cfg.WgPrivateKey)
	if err != nil {
		return fmt.Errorf("invalid private key: %w", err)
	}
	if cfg.WgPeerPublicKey == "" {
		return fmt.Errorf("WIREGUARD_PEER_PUBLIC_KEY is required when WireGuard is enabled")
	}
	publicKeyHex, err := decodeWireGuardKey(cfg.WgPeerPublicKey)
	if err != nil {
		return fmt.Errorf("invalid peer public key: %w", err)
	}

	localIPs := parseAddrList(cfg.WgAddress, "local IP")
	for _, addr := range localIPs {
		log.Printf("[INFO] Local VPN IP: %s", addr)
	}

	dnsServers := parseAddrList(cfg.WgDNS, "DNS IP")
	if len(dnsServers) == 0 {
		log.Printf("[WARN] No usable DNS IP in %q, using default 1.1.1.1", cfg.WgDNS)
		dnsServers = []netip.Addr{netip.AddrFrom4([4]byte{1, 1, 1, 1})}
	}
	for _, dns := range dnsServers {
		log.Printf("[INFO] DNS Server: %s", dns)
	}

	log.Println("[INFO] Creating virtual network interface...")
	tunDev, tnetInstance, err := netstack.CreateNetTUN(localIPs, dnsServers, 1420)
	if err != nil {
		return fmt.Errorf("failed to create TUN: %w", err)
	}
	log.Println("[INFO] Virtual TUN device created successfully")

	log.Println("[INFO] Initializing WireGuard device...")
	dev := device.NewDevice(tunDev, conn.NewDefaultBind(), device.NewLogger(device.LogLevelSilent, ""))

	log.Printf("[INFO] Configuring peer endpoint: %s", cfg.WgPeerEndpoint)
	uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nendpoint=%s\nallowed_ip=0.0.0.0/0\n",
		privateKeyHex, publicKeyHex, cfg.WgPeerEndpoint)
	if err := dev.IpcSet(uapi); err != nil {
		dev.Close() // don't leak the half-configured device
		return fmt.Errorf("failed to configure device: %w", err)
	}
	log.Println("[INFO] WireGuard peer configured")

	if err := dev.Up(); err != nil {
		dev.Close()
		return fmt.Errorf("failed to bring up device: %w", err)
	}

	// Publish only now that the tunnel is fully usable.
	tnet = tnetInstance
	log.Println("[SUCCESS] WireGuard interface is UP - All traffic will route through VPN")
	return nil
}

// decodeWireGuardKey converts a Base64 WireGuard key (surrounding whitespace
// tolerated) into the hex form expected by the UAPI. It rejects anything that
// does not decode to exactly 32 bytes.
func decodeWireGuardKey(b64 string) (string, error) {
	const keyLen = 32
	keyHex, err := base64ToHex(strings.TrimSpace(b64))
	if err != nil {
		return "", fmt.Errorf("base64 decode failed: %w", err)
	}
	if len(keyHex) != 2*keyLen {
		return "", fmt.Errorf("expected a %d-byte key, got %d bytes", keyLen, len(keyHex)/2)
	}
	return keyHex, nil
}

// parseAddrList parses a comma-separated list of IP addresses, each optionally
// in CIDR form; the prefix length is discarded. Entries that fail to parse are
// logged using label and skipped. Empty input yields a nil slice.
func parseAddrList(list, label string) []netip.Addr {
	var addrs []netip.Addr
	for _, entry := range strings.Split(list, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		host, _, _ := strings.Cut(entry, "/")
		addr, err := netip.ParseAddr(host)
		if err != nil {
			log.Printf("[WARN] Failed to parse %s: %v", label, err)
			continue
		}
		addrs = append(addrs, addr)
	}
	return addrs
}
// main loads configuration from the environment and optionally brings up the
// userspace WireGuard tunnel. It then serves the proxy on cfg.Port until the
// listener fails.
func main() {
	log.SetFlags(log.LstdFlags | log.Lmsgprefix)
	log.Println("[STARTUP] Initializing HTTP Proxy with Userspace WireGuard")

	cfg := params{}
	if err := env.Parse(&cfg); err != nil {
		log.Printf("[WARN] Config parse warning: %+v\n", err)
	}

	log.Printf("[CONFIG] Proxy Port: %s", cfg.Port)
	// Log with the same condition the handler uses to enforce auth.
	if authEnabled(cfg) {
		log.Printf("[CONFIG] Authentication: Enabled (user: %s)", cfg.User)
	} else {
		log.Println("[CONFIG] Authentication: Disabled")
	}

	if err := startWireGuard(cfg); err != nil {
		log.Fatalf("[FATAL] Failed to start WireGuard: %v", err)
	}

	log.Printf("[STARTUP] Starting HTTP proxy server on port %s\n", cfg.Port)
	server := &http.Server{
		Addr:    ":" + cfg.Port,
		Handler: newProxyHandler(cfg),
		// Bound header reads and keep-alive idling only. A ReadTimeout or
		// WriteTimeout would also cut off long-lived CONNECT tunnels and
		// streamed responses.
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}

	log.Println("[READY] Proxy server is ready to accept connections")
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf("[FATAL] Server error: %v", err)
	}
}

// authEnabled reports whether proxy authentication is enforced, which happens
// only when both a user name and a password are configured.
func authEnabled(cfg params) bool {
	return cfg.User != "" && cfg.Password != ""
}

// newProxyHandler returns the server's single handler. It enforces
// authentication when configured, tunnels CONNECT requests, and answers the
// built-in "/" health and "/debug" endpoints for direct (non-proxy) requests.
// Everything else is forwarded as a plain HTTP proxy request.
func newProxyHandler(cfg params) http.Handler {
	requireAuth := authEnabled(cfg)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if requireAuth {
			user, pass, ok := proxyCredentials(r)
			if !ok || !credentialsMatch(user, pass, cfg.User, cfg.Password) {
				log.Printf("[AUTH] Unauthorized access attempt from %s", r.RemoteAddr)
				w.Header().Set("Proxy-Authenticate", `Basic realm="Proxy"`)
				http.Error(w, "Unauthorized", http.StatusProxyAuthRequired)
				return
			}
		}

		// Handle CONNECT (HTTPS tunnel)
		if r.Method == http.MethodConnect {
			log.Printf("[CONNECT] %s -> %s", r.RemoteAddr, r.Host)
			handleTunneling(w, r)
			return
		}

		// An empty r.URL.Host means the request is addressed to the proxy
		// itself rather than an absolute-URL proxy request.
		if r.URL.Host == "" {
			switch r.URL.Path {
			case "/":
				log.Printf("[HEALTH] Health check from %s", r.RemoteAddr)
				w.WriteHeader(http.StatusOK)
				if tnet != nil {
					w.Write([]byte("Proxy Running via Userspace WireGuard"))
				} else {
					w.Write([]byte("Proxy Running in Direct Mode (No VPN)"))
				}
				return
			case "/debug":
				log.Printf("[DEBUG] Debug check from %s", r.RemoteAddr)
				handleDebug(w, r)
				return
			}
		}

		// Proxy HTTP requests
		log.Printf("[HTTP] %s %s -> %s", r.Method, r.RemoteAddr, r.URL.String())
		handleHTTP(w, r)
	})
}

// proxyCredentials extracts Basic credentials from r. It prefers the
// Proxy-Authorization header that proxy clients send (RFC 7235 §4.4) and falls
// back to Authorization, which direct requests to "/" or "/debug" carry.
func proxyCredentials(r *http.Request) (user, pass string, ok bool) {
	if h := r.Header.Get("Proxy-Authorization"); h != "" {
		return parseBasicAuth(h)
	}
	return r.BasicAuth()
}

// parseBasicAuth decodes a `Basic <base64(user:pass)>` header value.
func parseBasicAuth(header string) (user, pass string, ok bool) {
	scheme, encoded, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Basic") {
		return "", "", false
	}
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(encoded))
	if err != nil {
		return "", "", false
	}
	return strings.Cut(string(decoded), ":")
}

// credentialsMatch compares the supplied credentials with the configured ones
// in constant time. Both comparisons always run, so timing does not reveal
// which field was wrong.
func credentialsMatch(gotUser, gotPass, wantUser, wantPass string) bool {
	userOK := subtle.ConstantTimeCompare([]byte(gotUser), []byte(wantUser)) == 1
	passOK := subtle.ConstantTimeCompare([]byte(gotPass), []byte(wantPass)) == 1
	return userOK && passOK
}
// base64ToHex re-encodes standard (padded) Base64 text as lowercase hex.
// If b64 is not valid Base64, it returns "" and the decoder's error unchanged.
func base64ToHex(b64 string) (string, error) {
	raw, decodeErr := base64.StdEncoding.DecodeString(b64)
	if decodeErr != nil {
		return "", decodeErr
	}
	out := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(out, raw)
	return string(out), nil
}