| """
|
| VPN service implementation with database integration
|
| """
|
| from sqlalchemy.orm import Session
|
| from typing import List, Optional
|
| from datetime import datetime
|
| import json
|
|
|
| from models.database import (
|
| User,
|
| VPNSession,
|
| UserVPNConfig,
|
| BandwidthUsage,
|
| ServerConfig
|
| )
|
| from core.nat_engine import NATEngine
|
| from core.outline_server import OutlineServer
|
| from core.ikev2_server import IKEv2Server
|
| from schemas.vpn import VPNConfigResponse, VPNSessionResponse, VPNServerStats
|
|
|
class VPNService:
    """Database-backed operations on users' VPN configs, sessions and usage.

    Protocol-specific work is delegated to ``OutlineServer`` (protocol
    ``"outline"``) or ``IKEv2Server`` (protocol ``"ikev2"``); results are
    persisted through the SQLAlchemy session passed to the constructor.

    NOTE(review): the ``Session`` calls below are synchronous, so although the
    public methods are ``async`` they block the event loop while querying or
    committing. Consider an async session or a thread offload if that matters.

    Timestamps use naive ``datetime.utcnow()``; stored session times are
    presumably naive UTC too (they are subtracted from it) -- mixing naive and
    aware datetimes would raise ``TypeError``, so do not switch one side only.
    """

    def __init__(self, db: Session):
        self.db = db
        # Not referenced by any method of this class; presumably used by
        # callers elsewhere -- verify before removing.
        self.nat_engine = NATEngine()
        self.outline_server = OutlineServer()
        self.ikev2_server = IKEv2Server()

    def _protocol_server(self, protocol: str):
        """Return the backend handling *protocol*, or None if unsupported."""
        if protocol == "outline":
            return self.outline_server
        if protocol == "ikev2":
            return self.ikev2_server
        return None

    def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails.

        After a failed flush/commit SQLAlchemy requires an explicit rollback
        before the session can be used again; re-raise so callers still see
        the original error.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    async def get_user_config(self, user_id: int) -> VPNConfigResponse:
        """Return the stored VPN configuration for a user, creating it if absent.

        When no ``UserVPNConfig`` row exists, one is generated by the backend
        for ``user.vpn_protocol``, stored as JSON, and committed.

        Raises:
            ValueError: if the user does not exist, or a config has to be
                generated for a protocol other than "outline"/"ikev2".
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ValueError("User not found")

        config = self.db.query(UserVPNConfig).filter(
            UserVPNConfig.user_id == user_id
        ).first()

        if config is None:
            server = self._protocol_server(user.vpn_protocol)
            if server is None:
                raise ValueError(f"Unsupported VPN protocol: {user.vpn_protocol}")

            generated = await server.generate_user_config(user_id)
            raw_config = json.dumps(generated)
            config = UserVPNConfig(
                user_id=user_id,
                protocol=user.vpn_protocol,
                config_data=raw_config,
            )
            self.db.add(config)
            self._commit()
        else:
            raw_config = config.config_data

        # Deserialize the JSON string we already hold rather than re-reading
        # the attribute after commit (which would trigger a reload).
        return VPNConfigResponse(**json.loads(raw_config))

    async def get_user_sessions(self, user_id: int) -> List[VPNSessionResponse]:
        """Return the user's sessions whose status is ``"active"``.

        ``last_active`` is the session's ``end_time`` when set, otherwise the
        current UTC time (one timestamp shared by all returned sessions).
        """
        sessions = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id,
            VPNSession.status == "active",
        ).all()

        now = datetime.utcnow()
        return [
            VPNSessionResponse(
                session_id=str(session.id),
                start_time=session.start_time,
                last_active=session.end_time or now,
                protocol=session.protocol,
                client_ip=session.client_ip,
                bytes_sent=session.bytes_sent,
                bytes_received=session.bytes_received,
                status=session.status,
            )
            for session in sessions
        ]

    async def get_user_stats(self, user_id: int) -> VPNServerStats:
        """Return usage statistics for a user.

        - ``total_data_transferred``: sum of ``bytes_up + bytes_down`` over all
          of the user's ``BandwidthUsage`` rows.
        - ``active_sessions``: number of sessions with status ``"active"``.
        - ``total_session_time``: whole seconds elapsed across the *active*
          sessions only, each measured as ``(end_time or now) - start_time``
          and truncated to an int per session.
        - ``last_connection``: ``start_time`` of the most recently started
          session (any status), or None if the user has none.
        - ``bandwidth_usage``: ``{protocol: {'up': int, 'down': int}}``.
        """
        active_sessions = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id,
            VPNSession.status == "active",
        ).all()

        bandwidth_usage = self.db.query(BandwidthUsage).filter(
            BandwidthUsage.user_id == user_id
        ).all()

        last_session = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id
        ).order_by(VPNSession.start_time.desc()).first()

        # One pass builds both the per-protocol breakdown and the grand total.
        protocol_usage = {}
        total_bytes = 0
        for usage in bandwidth_usage:
            bucket = protocol_usage.setdefault(usage.protocol, {'up': 0, 'down': 0})
            bucket['up'] += usage.bytes_up
            bucket['down'] += usage.bytes_down
            total_bytes += usage.bytes_up + usage.bytes_down

        # Parenthesize the `or` explicitly: without it, `-` binds first and a
        # set end_time would yield a datetime with no .total_seconds().
        now = datetime.utcnow()
        total_session_time = sum(
            int(((s.end_time or now) - s.start_time).total_seconds())
            for s in active_sessions
        )

        return VPNServerStats(
            total_data_transferred=total_bytes,
            active_sessions=len(active_sessions),
            total_session_time=total_session_time,
            last_connection=last_session.start_time if last_session else None,
            bandwidth_usage=protocol_usage,
        )

    async def disconnect_session(self, session_id: str, user_id: int) -> bool:
        """Disconnect one of the user's active VPN sessions.

        The protocol backend is asked to drop the session (skipped for
        protocols other than "outline"/"ikev2"), then the row is marked
        ``"disconnected"`` with ``end_time`` set to now and committed.

        Returns:
            True if an active session with this id owned by *user_id* was
            found and disconnected, False otherwise.
        """
        session = self.db.query(VPNSession).filter(
            VPNSession.id == session_id,
            VPNSession.user_id == user_id,
            VPNSession.status == "active",
        ).first()

        if not session:
            return False

        server = self._protocol_server(session.protocol)
        if server is not None:
            await server.disconnect_session(session_id)

        session.status = "disconnected"
        session.end_time = datetime.utcnow()
        self._commit()

        return True

    async def update_bandwidth_usage(
        self, user_id: int, protocol: str, bytes_up: int, bytes_down: int
    ) -> None:
        """Record one bandwidth-usage sample for a user and commit it.

        Each call inserts a new ``BandwidthUsage`` row; totals are computed by
        summing rows in ``get_user_stats``.
        """
        usage = BandwidthUsage(
            user_id=user_id,
            protocol=protocol,
            bytes_up=bytes_up,
            bytes_down=bytes_down,
        )
        self.db.add(usage)
        self._commit()
|
|
|