File size: 4,491 Bytes
6a5b8d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""

Shadowsocks Protocol Implementation

"""

import asyncio
import hashlib
import ipaddress
import logging
import os
from typing import Optional, Tuple

from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305

logger = logging.getLogger(__name__)

class ShadowsocksProtocol:
    """Server side of a Shadowsocks-style encrypted TCP relay.

    The client's first packet is a ChaCha20-Poly1305-encrypted blob whose
    plaintext starts with a SOCKS5-style address header
    (``ATYP | address | port``), optionally followed by initial payload.
    The server connects to that address, forwards the initial payload, and
    then relays bytes in both directions, encrypting target->client chunks.

    NOTE(review): this framing is not the Shadowsocks AEAD wire format
    (no per-session salt/HKDF subkey, no length-prefixed chunks, random
    rather than counter nonces), so it presumably will not interoperate with
    stock Shadowsocks clients. It also assumes the whole first encrypted
    packet arrives in a single ``read()``, and client->target traffic after
    the first packet is forwarded WITHOUT decryption while target->client
    chunks are encrypted without length framing. Confirm this matches the
    paired client before relying on it.
    """

    CHUNK_SIZE = 8192  # maximum number of bytes requested per read()

    # Address types used in the address header (same values as SOCKS5 ATYP).
    ATYP_IPV4 = 1
    ATYP_DOMAIN = 3
    ATYP_IPV6 = 4

    def __init__(self, access_key: str):
        """Create a relay whose cipher key is derived from *access_key*."""
        self.access_key = access_key
        self.cipher = self._create_cipher()
        # Not read anywhere in this class; kept for external users of the attribute.
        self.buffer = bytearray()

    def _create_cipher(self) -> "ChaCha20Poly1305":
        """Return a ChaCha20-Poly1305 cipher keyed with SHA-256(access_key).

        SHA-256 produces exactly the 32-byte key ChaCha20-Poly1305 requires.
        """
        key = hashlib.sha256(self.access_key.encode()).digest()
        return ChaCha20Poly1305(key)

    async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client connection end to end.

        Reads and decrypts the first packet, parses the target address from
        it, connects to the target, forwards any payload that followed the
        address header, then relays data in both directions until either side
        closes. Errors are logged, not raised; both connections are closed on
        exit.

        Args:
            reader: Stream reading from the client.
            writer: Stream writing to the client.
        """
        target_writer = None
        try:
            # NOTE(review): assumes the entire first encrypted packet arrives in one read().
            data = await reader.read(self.CHUNK_SIZE)
            if not data:
                return

            decrypted = self._decrypt_packet(data)
            header = self._parse_address_header(decrypted)
            if header is None:
                logger.error("Invalid target address")
                return
            host, port, payload_offset = header

            target_reader, target_writer = await asyncio.open_connection(host, port)

            # Bytes after the address header are the client's first request;
            # they must reach the target before the relay loop starts.
            initial_payload = decrypted[payload_offset:]
            if initial_payload:
                target_writer.write(initial_payload)
                await target_writer.drain()

            await self._proxy_data(reader, writer, target_reader, target_writer)

        except Exception as e:
            logger.error("Connection error: %s", e)
        finally:
            # Closing an already-closed writer is harmless; _close_writer also
            # tolerates connections that were reset by the peer.
            if target_writer is not None:
                await self._close_writer(target_writer)
            await self._close_writer(writer)

    @staticmethod
    async def _close_writer(writer: asyncio.StreamWriter) -> None:
        """Close *writer* and wait for it, ignoring OS errors from a broken connection."""
        writer.close()
        try:
            await writer.wait_closed()
        except OSError as e:
            # Best-effort teardown: the peer may already have reset the socket.
            logger.debug("Error while closing writer: %s", e)

    async def _proxy_data(self,

                         client_reader: asyncio.StreamReader,

                         client_writer: asyncio.StreamWriter,

                         target_reader: asyncio.StreamReader,

                         target_writer: asyncio.StreamWriter):
        """Relay data between client and target until both directions finish.

        Client->target bytes are forwarded as-is; target->client chunks are
        encrypted with ``_encrypt_packet`` before being written.
        """
        async def forward(reader: asyncio.StreamReader,

                        writer: asyncio.StreamWriter,

                        encrypt: bool = False):
            # Copy chunks from reader to writer until EOF or an error, then
            # close the writer. Closing (rather than half-closing) the socket
            # means the opposite direction ends as well.
            try:
                while True:
                    data = await reader.read(self.CHUNK_SIZE)
                    if not data:
                        break
                    if encrypt:
                        data = self._encrypt_packet(data)
                    writer.write(data)
                    await writer.drain()
            except Exception as e:
                logger.error("Forward error: %s", e)
            finally:
                await self._close_writer(writer)

        await asyncio.gather(
            forward(client_reader, target_writer, encrypt=False),
            forward(target_reader, client_writer, encrypt=True)
        )

    def _encrypt_packet(self, data: bytes) -> bytes:
        """Encrypt *data* under a fresh random 12-byte nonce.

        Returns:
            ``nonce + ciphertext`` where the ciphertext carries the Poly1305
            tag appended by the cipher.
        """
        nonce = os.urandom(12)
        encrypted = self.cipher.encrypt(nonce, data, None)
        return nonce + encrypted

    def _decrypt_packet(self, data: bytes) -> bytes:
        """Decrypt a ``nonce + ciphertext`` packet produced by ``_encrypt_packet``.

        Raises:
            cryptography.exceptions.InvalidTag: if authentication fails
            (wrong key, tampered or truncated data).
        """
        nonce, ciphertext = data[:12], data[12:]
        return self.cipher.decrypt(nonce, ciphertext, None)

    def _parse_address_header(self, data: bytes) -> Optional[Tuple[str, int, int]]:
        """Parse the address header at the start of *data*.

        Layout: a 1-byte ATYP, then either a 4-byte IPv4 address, a 1-byte
        length followed by that many domain-name bytes, or a 16-byte IPv6
        address, and finally a 2-byte big-endian port.

        Returns:
            ``(host, port, payload_offset)``, where ``payload_offset`` is the
            index of the first byte after the header. Returns ``None`` if the
            header is empty, truncated, has a zero-length or undecodable
            domain, or uses an unsupported ATYP.
        """
        if not data:
            return None

        atyp = data[0]
        if atyp == self.ATYP_IPV4:
            addr_start, addr_end = 1, 5
        elif atyp == self.ATYP_DOMAIN:
            if len(data) < 2 or data[1] == 0:
                return None
            addr_start, addr_end = 2, 2 + data[1]
        elif atyp == self.ATYP_IPV6:
            addr_start, addr_end = 1, 17
        else:
            return None

        port_end = addr_end + 2
        if len(data) < port_end:
            return None  # header truncated: address or port bytes missing

        raw_addr = bytes(data[addr_start:addr_end])
        if atyp == self.ATYP_DOMAIN:
            try:
                host = raw_addr.decode()
            except UnicodeDecodeError as e:
                logger.error("Error extracting address: %s", e)
                return None
        else:
            # 4 packed bytes -> IPv4Address, 16 -> IPv6Address; str() gives
            # the canonical text form (e.g. "2001:db8::1") open_connection accepts.
            host = str(ipaddress.ip_address(raw_addr))

        port = int.from_bytes(data[addr_end:port_end], 'big')
        return host, port, port_end

    def _extract_address(self, data: bytes) -> Optional[Tuple[str, int]]:
        """Return ``(host, port)`` from the address header, or ``None`` if it is invalid."""
        header = self._parse_address_header(data)
        if header is None:
            return None
        host, port, _ = header
        return host, port