File size: 5,869 Bytes
6a5b8d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""

VPN service implementation with database integration

"""
from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
import json

from models.database import (
    User, 
    VPNSession, 
    UserVPNConfig, 
    BandwidthUsage,
    ServerConfig
)
from core.nat_engine import NATEngine
from core.outline_server import OutlineServer
from core.ikev2_server import IKEv2Server
from schemas.vpn import VPNConfigResponse, VPNSessionResponse, VPNServerStats

class VPNService:
    """Coordinate the VPN protocol backends with persisted VPN state.

    Holds a database session plus NAT, Outline and IKEv2 backend objects,
    and exposes async operations for per-user configuration, session
    listing/disconnection, usage statistics and bandwidth accounting.
    """

    def __init__(self, db: Session):
        """Bind the service to *db* and instantiate the protocol backends.

        Args:
            db: Database session used for every query and commit below.
        """
        self.db = db
        self.nat_engine = NATEngine()
        self.outline_server = OutlineServer()
        self.ikev2_server = IKEv2Server()

    def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails.

        Per SQLAlchemy's session semantics, a failed flush/commit leaves the
        session unusable until ``rollback()`` is called, so roll back before
        re-raising so later calls on this service are not poisoned.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def _active_sessions(self, user_id: int) -> List[VPNSession]:
        """Return every VPNSession row of *user_id* whose status is "active"."""
        return self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id,
            VPNSession.status == "active"
        ).all()

    async def get_user_config(self, user_id: int) -> VPNConfigResponse:
        """Return the VPN configuration for a user, creating it on first use.

        If no ``UserVPNConfig`` row exists yet, one is generated by the
        backend matching ``user.vpn_protocol`` ("outline" or "ikev2"),
        serialized to JSON, stored and committed.

        Args:
            user_id: Id of the user whose configuration is requested.

        Returns:
            The stored configuration unpacked into a ``VPNConfigResponse``.

        Raises:
            ValueError: If the user does not exist, or a config must be
                generated and the user's protocol is not "outline"/"ikev2".
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ValueError("User not found")

        # Get or create the user's VPN config
        config = self.db.query(UserVPNConfig).filter(
            UserVPNConfig.user_id == user_id
        ).first()

        if not config:
            # Generate a new configuration with the backend for the protocol
            if user.vpn_protocol == "outline":
                config_data = await self.outline_server.generate_user_config(user_id)
            elif user.vpn_protocol == "ikev2":
                config_data = await self.ikev2_server.generate_user_config(user_id)
            else:
                raise ValueError(f"Unsupported VPN protocol: {user.vpn_protocol}")

            config = UserVPNConfig(
                user_id=user_id,
                protocol=user.vpn_protocol,
                config_data=json.dumps(config_data)
            )
            self.db.add(config)
            self._commit()

        # Always parse from the stored JSON so freshly generated and existing
        # configs take the same path into the response model.
        config_data = json.loads(config.config_data)
        return VPNConfigResponse(**config_data)

    async def get_user_sessions(self, user_id: int) -> List[VPNSessionResponse]:
        """List the user's sessions whose status is "active".

        Args:
            user_id: Id of the user whose sessions are listed.

        Returns:
            One ``VPNSessionResponse`` per active session. ``last_active`` is
            the session's ``end_time`` when set, otherwise the current UTC
            time (naive, matching ``datetime.utcnow()`` used elsewhere here).
        """
        return [
            VPNSessionResponse(
                session_id=str(session.id),
                start_time=session.start_time,
                last_active=session.end_time or datetime.utcnow(),
                protocol=session.protocol,
                client_ip=session.client_ip,
                bytes_sent=session.bytes_sent,
                bytes_received=session.bytes_received,
                status=session.status
            )
            for session in self._active_sessions(user_id)
        ]

    async def get_user_stats(self, user_id: int) -> VPNServerStats:
        """Aggregate VPN usage statistics for a user.

        Args:
            user_id: Id of the user whose statistics are computed.

        Returns:
            A ``VPNServerStats`` with:
              * ``total_data_transferred`` - sum of ``bytes_up + bytes_down``
                over all of the user's ``BandwidthUsage`` rows;
              * ``active_sessions`` - count of sessions with status "active";
              * ``total_session_time`` - summed whole seconds of the active
                sessions only, measured to ``end_time`` or to now if unset;
              * ``last_connection`` - ``start_time`` of the most recently
                started session of any status, or ``None``;
              * ``bandwidth_usage`` - ``{protocol: {'up': n, 'down': n}}``.
        """
        active_sessions = self._active_sessions(user_id)

        bandwidth_usage = self.db.query(BandwidthUsage).filter(
            BandwidthUsage.user_id == user_id
        ).all()

        # Totals and the per-protocol breakdown in a single pass over the rows.
        total_bytes_up = 0
        total_bytes_down = 0
        protocol_usage = {}
        for usage in bandwidth_usage:
            total_bytes_up += usage.bytes_up
            total_bytes_down += usage.bytes_down
            bucket = protocol_usage.setdefault(usage.protocol, {'up': 0, 'down': 0})
            bucket['up'] += usage.bytes_up
            bucket['down'] += usage.bytes_down

        last_session = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id
        ).order_by(VPNSession.start_time.desc()).first()

        # One reference instant so every open session is measured to the
        # same "now".
        now = datetime.utcnow()
        total_session_time = sum(
            # Parenthesize the `or`: pick the end point first, THEN subtract.
            # Unparenthesized, a session with end_time set yielded a datetime
            # and `.total_seconds()` raised AttributeError.
            int(((s.end_time or now) - s.start_time).total_seconds())
            for s in active_sessions
        )

        return VPNServerStats(
            total_data_transferred=total_bytes_up + total_bytes_down,
            active_sessions=len(active_sessions),
            total_session_time=total_session_time,
            last_connection=last_session.start_time if last_session else None,
            bandwidth_usage=protocol_usage
        )

    async def disconnect_session(self, session_id: str, user_id: int) -> bool:
        """Disconnect one of the user's active VPN sessions.

        Args:
            session_id: Id of the session (as returned by get_user_sessions).
            user_id: Owner of the session; sessions of other users are ignored.

        Returns:
            ``False`` if no matching active session exists, else ``True``
            after the session is marked "disconnected" and committed.
        """
        session = self.db.query(VPNSession).filter(
            VPNSession.id == session_id,
            VPNSession.user_id == user_id,
            VPNSession.status == "active"
        ).first()

        if not session:
            return False

        # Tear down on the protocol backend. Sessions with any other protocol
        # value have no backend call and are only marked disconnected below.
        if session.protocol == "outline":
            await self.outline_server.disconnect_session(session_id)
        elif session.protocol == "ikev2":
            await self.ikev2_server.disconnect_session(session_id)

        session.status = "disconnected"
        session.end_time = datetime.utcnow()
        self._commit()

        return True

    async def update_bandwidth_usage(self, user_id: int, protocol: str, bytes_up: int, bytes_down: int) -> None:
        """Record one bandwidth-usage sample for a user and commit it.

        Each call inserts a new ``BandwidthUsage`` row; get_user_stats sums
        all rows, so callers should pass deltas rather than running totals.
        """
        usage = BandwidthUsage(
            user_id=user_id,
            protocol=protocol,
            bytes_up=bytes_up,
            bytes_down=bytes_down
        )
        self.db.add(usage)
        self._commit()