JRNET / services /vpn_service.py
Factor Studios
Upload 96 files
6a5b8d8 verified
"""
VPN service implementation with database integration
"""
from sqlalchemy.orm import Session
from typing import List, Optional
from datetime import datetime
import json
from models.database import (
User,
VPNSession,
UserVPNConfig,
BandwidthUsage,
ServerConfig
)
from core.nat_engine import NATEngine
from core.outline_server import OutlineServer
from core.ikev2_server import IKEv2Server
from schemas.vpn import VPNConfigResponse, VPNSessionResponse, VPNServerStats
class VPNService:
    """Service layer connecting the VPN protocol servers to the database.

    Each instance wraps one SQLAlchemy ``Session`` and creates its own
    ``NATEngine``, ``OutlineServer`` and ``IKEv2Server`` objects.

    NOTE(review): the public methods are ``async``, but every database call
    below is a synchronous ``Session`` call and blocks the event loop while it
    runs — consider an async session or running queries in a thread pool.
    """

    def __init__(self, db: Session):
        self.db = db
        self.nat_engine = NATEngine()
        self.outline_server = OutlineServer()
        self.ikev2_server = IKEv2Server()

    def _server_for(self, protocol: str):
        """Return the protocol server handling *protocol*, or ``None`` if unsupported."""
        if protocol == "outline":
            return self.outline_server
        if protocol == "ikev2":
            return self.ikev2_server
        return None

    def _commit(self) -> None:
        """Commit the current transaction, rolling back before re-raising on failure.

        A SQLAlchemy session whose flush/commit failed cannot be used again
        until ``rollback()`` is called, so this keeps ``self.db`` usable.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    async def get_user_config(self, user_id: int) -> VPNConfigResponse:
        """Return a user's VPN configuration, generating and storing it on first use.

        When no ``UserVPNConfig`` row exists for the user, a configuration is
        generated by the server matching ``user.vpn_protocol`` and saved as a
        JSON string.

        Raises:
            ValueError: if the user does not exist, or if a new configuration is
                needed and ``user.vpn_protocol`` is not "outline" or "ikev2".
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ValueError("User not found")

        config = self.db.query(UserVPNConfig).filter(
            UserVPNConfig.user_id == user_id
        ).first()
        # NOTE(review): an existing config is returned even if its protocol no
        # longer matches user.vpn_protocol — confirm that is intended.
        if not config:
            server = self._server_for(user.vpn_protocol)
            if server is None:
                raise ValueError(f"Unsupported VPN protocol: {user.vpn_protocol}")
            config_data = await server.generate_user_config(user_id)
            config = UserVPNConfig(
                user_id=user_id,
                protocol=user.vpn_protocol,
                config_data=json.dumps(config_data),
            )
            self.db.add(config)
            self._commit()

        # Always decode from the stored JSON so freshly generated and cached
        # configurations go through the same path and yield the same response.
        return VPNConfigResponse(**json.loads(config.config_data))

    async def get_user_sessions(self, user_id: int) -> List[VPNSessionResponse]:
        """Return the user's sessions whose status is "active".

        ``last_active`` is the session's ``end_time`` if set, otherwise the
        current UTC time (one timestamp shared by all sessions in the result).
        """
        sessions = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id,
            VPNSession.status == "active"
        ).all()
        now = datetime.utcnow()
        return [
            VPNSessionResponse(
                session_id=str(session.id),
                start_time=session.start_time,
                last_active=session.end_time or now,
                protocol=session.protocol,
                client_ip=session.client_ip,
                bytes_sent=session.bytes_sent,
                bytes_received=session.bytes_received,
                status=session.status
            )
            for session in sessions
        ]

    async def get_user_stats(self, user_id: int) -> VPNServerStats:
        """Aggregate VPN usage statistics for a user.

        - ``total_data_transferred``: sum of ``bytes_up + bytes_down`` over all
          of the user's ``BandwidthUsage`` rows.
        - ``active_sessions``: number of sessions with status "active".
        - ``total_session_time``: seconds elapsed for the active sessions,
          measured from ``start_time`` to ``end_time`` (or now if unset).
        - ``last_connection``: ``start_time`` of the most recent session, or None.
        - ``bandwidth_usage``: ``{protocol: {'up': int, 'down': int}}`` totals.
        """
        active_sessions = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id,
            VPNSession.status == "active"
        ).all()

        bandwidth_usage = self.db.query(BandwidthUsage).filter(
            BandwidthUsage.user_id == user_id
        ).all()

        last_session = self.db.query(VPNSession).filter(
            VPNSession.user_id == user_id
        ).order_by(VPNSession.start_time.desc()).first()

        # Single pass: per-protocol totals and overall totals together.
        protocol_usage = {}
        total_bytes_up = 0
        total_bytes_down = 0
        for usage in bandwidth_usage:
            totals = protocol_usage.setdefault(usage.protocol, {'up': 0, 'down': 0})
            totals['up'] += usage.bytes_up
            totals['down'] += usage.bytes_down
            total_bytes_up += usage.bytes_up
            total_bytes_down += usage.bytes_down

        now = datetime.utcnow()
        # The parentheses around the `or` are required: `a or b - c` parses as
        # `a or (b - c)`, which yields a datetime (no .total_seconds()) whenever
        # end_time is set.
        total_session_time = sum(
            int(((s.end_time or now) - s.start_time).total_seconds())
            for s in active_sessions
        )

        return VPNServerStats(
            total_data_transferred=total_bytes_up + total_bytes_down,
            active_sessions=len(active_sessions),
            total_session_time=total_session_time,
            last_connection=last_session.start_time if last_session else None,
            bandwidth_usage=protocol_usage
        )

    async def disconnect_session(self, session_id: str, user_id: int) -> bool:
        """Disconnect one of the user's active sessions.

        Returns False if no active session with this id belongs to the user,
        True after the session has been marked "disconnected".

        If the protocol server's disconnect raises, the database row is left
        unchanged. Sessions whose protocol is neither "outline" nor "ikev2" are
        only marked disconnected in the database.
        """
        session = self.db.query(VPNSession).filter(
            VPNSession.id == session_id,
            VPNSession.user_id == user_id,
            VPNSession.status == "active"
        ).first()
        if not session:
            return False

        server = self._server_for(session.protocol)
        if server is not None:
            await server.disconnect_session(session_id)

        session.status = "disconnected"
        session.end_time = datetime.utcnow()
        self._commit()
        return True

    async def update_bandwidth_usage(self, user_id: int, protocol: str, bytes_up: int, bytes_down: int) -> None:
        """Record one bandwidth-usage sample as a new ``BandwidthUsage`` row."""
        usage = BandwidthUsage(
            user_id=user_id,
            protocol=protocol,
            bytes_up=bytes_up,
            bytes_down=bytes_down
        )
        self.db.add(usage)
        self._commit()