"""
Enhanced User Model with Authentication and VPN Client Management

This module provides comprehensive user management with security features,
VPN client management, and session tracking capabilities.
"""

import re
import secrets
from datetime import datetime, timedelta

import jwt
from flask import current_app
from werkzeug.security import check_password_hash, generate_password_hash

from .user import db


# Compiled once at import time. These are used with ``fullmatch`` so the whole
# string must match: with ``re.match`` + ``$``, a trailing "\n" would be accepted.
_EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')
_USERNAME_RE = re.compile(r'[a-zA-Z0-9_-]+')

# Character classes a password must each contain at least once.
_PASSWORD_REQUIRED_PATTERNS = (
    r'[A-Z]',
    r'[a-z]',
    r'\d',
    r'[!@#$%^&*(),.?":{}|<>]',
)


def _tokens_match(expected, provided):
    """Compare a stored secret token with a caller-supplied one in constant time.

    Args:
        expected: Token stored on the model; ``None`` once it has been consumed.
        provided: Token supplied by the caller (presumably user input).

    Returns:
        bool: ``True`` only when both values are strings and are equal. A
        cleared (``None``) stored token never matches, not even ``None``.
    """
    if not isinstance(expected, str) or not isinstance(provided, str):
        return False
    # compare_digest raises TypeError for non-ASCII ``str`` arguments, so
    # compare the UTF-8 bytes instead.
    return secrets.compare_digest(expected.encode('utf-8'), provided.encode('utf-8'))


def _isoformat(value):
    """Return ``value.isoformat()``, or ``None`` when ``value`` is unset."""
    return value.isoformat() if value else None


class User(db.Model):
    """Enhanced User model with authentication and VPN management."""

    __tablename__ = 'users'

    # Maximum number of *active* VPN clients per subscription tier. A tier
    # missing from this table may not create any clients.
    _VPN_CLIENT_LIMITS = {'free': 1, 'premium': 5, 'enterprise': 50}
    # Consecutive failed logins that trigger a lockout, and how long it lasts.
    _MAX_FAILED_LOGINS = 5
    _LOCKOUT_DURATION = timedelta(minutes=30)

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    salt = db.Column(db.String(32), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_login = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    is_admin = db.Column(db.Boolean, default=False)

    # Subscription and quotas
    subscription_type = db.Column(db.String(20), default='free')
    subscription_expires = db.Column(db.DateTime)
    max_concurrent_connections = db.Column(db.Integer, default=1)
    bandwidth_limit_mbps = db.Column(db.Integer, default=10)

    # Account security
    email_verified = db.Column(db.Boolean, default=False)
    email_verification_token = db.Column(db.String(64))
    two_factor_enabled = db.Column(db.Boolean, default=False)
    two_factor_secret = db.Column(db.String(32))
    password_reset_token = db.Column(db.String(64))
    password_reset_expires = db.Column(db.DateTime)
    failed_login_attempts = db.Column(db.Integer, default=0)
    account_locked_until = db.Column(db.DateTime)

    # Relationships
    vpn_clients = db.relationship('VPNClient', backref='user', lazy=True,
                                  cascade='all, delete-orphan')
    vpn_sessions = db.relationship('VPNSession', backref='user', lazy=True)

    def __init__(self, username, email, password=None):
        """Create a user and issue a fresh email-verification token.

        Args:
            username: Login name; not validated here (see ``validate_username``).
            email: Email address; not validated here (see ``validate_email``).
            password: Optional plaintext password. It is hashed if truthy.

        Raises:
            ValueError: If ``password`` is given but fails
                ``validate_password_strength``.
        """
        self.username = username
        self.email = email
        if password:
            self.set_password(password)
        self.email_verification_token = secrets.token_urlsafe(32)

    def set_password(self, password):
        """Hash and store ``password`` with a fresh per-user salt.

        This also clears the failed-login counter and any active lockout.

        Raises:
            ValueError: If the password fails ``validate_password_strength``.
        """
        if not self.validate_password_strength(password):
            raise ValueError("Password does not meet security requirements")
        # werkzeug's hash format already embeds its own salt. The extra column
        # salt is kept because stored hashes are computed over password + salt,
        # and check_password relies on that.
        self.salt = secrets.token_hex(16)
        self.password_hash = generate_password_hash(password + self.salt)
        self.failed_login_attempts = 0
        self.account_locked_until = None

    def check_password(self, password):
        """Verify ``password`` and update the lockout bookkeeping.

        On success, resets ``failed_login_attempts`` and stamps ``last_login``.
        On failure, increments the counter. When it reaches
        ``_MAX_FAILED_LOGINS``, the account is locked for ``_LOCKOUT_DURATION``.
        While the account is locked, this always returns ``False``.
        Changes are made on the instance only; committing them is left to
        the caller.

        Returns:
            bool: Whether the password matched.
        """
        if self.is_account_locked():
            return False
        if not self.password_hash or self.salt is None:
            # No password has been set (e.g. constructed without one).
            return False

        is_valid = check_password_hash(self.password_hash, password + self.salt)
        if is_valid:
            self.failed_login_attempts = 0
            self.last_login = datetime.utcnow()
        else:
            # The column default only applies on INSERT, so the counter can
            # still be None on an unflushed instance.
            self.failed_login_attempts = (self.failed_login_attempts or 0) + 1
            if self.failed_login_attempts >= self._MAX_FAILED_LOGINS:
                self.account_locked_until = datetime.utcnow() + self._LOCKOUT_DURATION
        return is_valid

    def is_account_locked(self):
        """Return whether the account is currently locked out.

        Side effect: if a lock has expired, it is cleared and the failed-login
        counter is reset.
        """
        if self.account_locked_until is None:
            return False
        if datetime.utcnow() < self.account_locked_until:
            return True
        # The lock has expired, so unlock the account.
        self.account_locked_until = None
        self.failed_login_attempts = 0
        return False

    @staticmethod
    def validate_password_strength(password):
        """Return True if ``password`` satisfies the complexity policy.

        The policy requires at least 8 characters, including at least one
        uppercase letter, one lowercase letter, one digit, and one character
        from !@#$%^&*(),.?":{}|<> . Non-string input returns False.
        """
        if not isinstance(password, str) or len(password) < 8:
            return False
        return all(re.search(pattern, password)
                   for pattern in _PASSWORD_REQUIRED_PATTERNS)

    @staticmethod
    def validate_email(email):
        """Return True if ``email`` has the shape ``local@domain.tld``.

        The entire string must match, so trailing whitespace or newlines are
        rejected. Non-string input returns False.
        """
        return isinstance(email, str) and _EMAIL_RE.fullmatch(email) is not None

    @staticmethod
    def validate_username(username):
        """Return True if ``username`` is 3-80 letters, digits, ``_`` or ``-``.

        Non-string input returns False.
        """
        if not isinstance(username, str) or not 3 <= len(username) <= 80:
            return False
        return _USERNAME_RE.fullmatch(username) is not None

    def generate_auth_token(self, expires_in=3600):
        """Issue a signed HS256 JWT access token for this user.

        Args:
            expires_in: Token lifetime in seconds (default 3600, one hour).

        Returns:
            The encoded token, exactly as returned by ``jwt.encode``.
        """
        now = datetime.utcnow()  # single timestamp so exp - iat == expires_in
        payload = {
            'user_id': self.id,
            'username': self.username,
            'email': self.email,
            'subscription_type': self.subscription_type,
            'is_admin': self.is_admin,
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    def generate_refresh_token(self, expires_in=2592000):  # 30 days
        """Issue a long-lived refresh token marked with ``type: 'refresh'``.

        The token carries only the user id. ``verify_auth_token`` rejects it,
        so it can only be exchanged through ``verify_refresh_token``.

        Args:
            expires_in: Token lifetime in seconds (default 2592000, 30 days).
        """
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'type': 'refresh',
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    @staticmethod
    def _decode_token(token):
        """Verify and decode an HS256 token signed with the app secret.

        Returns:
            dict | None: The payload, or ``None`` if the token is invalid or expired.
        """
        try:
            return jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
        except jwt.InvalidTokenError:
            # This also covers jwt.ExpiredSignatureError, which subclasses it.
            return None

    @staticmethod
    def _user_from_payload(payload):
        """Load the user referenced by ``payload['user_id']``, or None if absent."""
        user_id = payload.get('user_id')
        if user_id is None:
            return None
        return User.query.get(user_id)

    @staticmethod
    def verify_auth_token(token):
        """Return the user for a valid access token, else ``None``.

        Refresh tokens are rejected so they cannot be used for authentication.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') == 'refresh':
            return None
        return User._user_from_payload(payload)

    @staticmethod
    def verify_refresh_token(token):
        """Return the user for a valid refresh token, else ``None``.

        Tokens without ``type: 'refresh'`` (i.e. access tokens) are rejected.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') != 'refresh':
            return None
        return User._user_from_payload(payload)

    def generate_password_reset_token(self):
        """Create, store and return a password-reset token valid for one hour.

        ``reset_password`` clears the token after a successful reset.
        """
        self.password_reset_token = secrets.token_urlsafe(32)
        self.password_reset_expires = datetime.utcnow() + timedelta(hours=1)
        return self.password_reset_token

    def verify_password_reset_token(self, token):
        """Return True if ``token`` matches the stored reset token and has not expired."""
        return (_tokens_match(self.password_reset_token, token)
                and self.password_reset_expires is not None
                and datetime.utcnow() < self.password_reset_expires)

    def reset_password(self, new_password, token):
        """Set ``new_password`` if ``token`` is a valid reset token, then consume the token.

        Returns:
            bool: ``False`` if the token is invalid or expired, else ``True``.

        Raises:
            ValueError: If ``new_password`` fails the strength policy. In that
                case the token is left in place.
        """
        if not self.verify_password_reset_token(token):
            return False
        self.set_password(new_password)
        self.password_reset_token = None
        self.password_reset_expires = None
        return True

    def verify_email(self, token):
        """Mark the email verified if ``token`` matches; the token is then cleared.

        Returns:
            bool: Whether verification succeeded. Once the stored token has
            been cleared, nothing matches.
        """
        if _tokens_match(self.email_verification_token, token):
            self.email_verified = True
            self.email_verification_token = None
            return True
        return False

    def _active_client_count(self):
        """Return the number of this user's VPN clients that have ``is_active`` set."""
        return sum(1 for client in self.vpn_clients if client.is_active)

    def can_create_vpn_client(self):
        """Return whether another VPN client fits within this tier's active-client limit.

        Unknown subscription tiers may not create clients.
        """
        limit = self._VPN_CLIENT_LIMITS.get(self.subscription_type)
        if limit is None:
            return False
        return self._active_client_count() < limit

    def get_active_sessions_count(self):
        """Return the number of this user's sessions with no ``disconnected_at``."""
        return sum(1 for session in self.vpn_sessions if session.disconnected_at is None)

    def can_connect_vpn(self):
        """Return whether open sessions are below ``max_concurrent_connections``."""
        return self.get_active_sessions_count() < self.max_concurrent_connections

    def get_bandwidth_usage_today(self):
        """Return total bytes sent plus received by sessions that connected today (UTC date)."""
        today = datetime.utcnow().date()
        # Byte counters may still be None on sessions that are not yet flushed.
        return sum((s.bytes_received or 0) + (s.bytes_sent or 0)
                   for s in self.vpn_sessions
                   if s.connected_at and s.connected_at.date() == today)

    def is_subscription_active(self):
        """Return True for the free tier, or for a paid tier whose expiry lies in the future.

        Always returns a bool. A paid tier with no expiry date counts as inactive.
        """
        if self.subscription_type == 'free':
            return True
        return (self.subscription_expires is not None
                and datetime.utcnow() < self.subscription_expires)

    def to_dict(self, include_sensitive=False):
        """Serialize the user to a JSON-friendly dict.

        Args:
            include_sensitive: When truthy, admin/lockout/usage fields are
                added, but only if this user is an admin or the value is the
                string ``'self'``.
        """
        data = {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'created_at': _isoformat(self.created_at),
            'last_login': _isoformat(self.last_login),
            'is_active': self.is_active,
            'subscription_type': self.subscription_type,
            'subscription_expires': _isoformat(self.subscription_expires),
            'max_concurrent_connections': self.max_concurrent_connections,
            'bandwidth_limit_mbps': self.bandwidth_limit_mbps,
            'email_verified': self.email_verified,
            'two_factor_enabled': self.two_factor_enabled,
            'is_subscription_active': self.is_subscription_active(),
            'active_vpn_clients': self._active_client_count(),
            'active_sessions': self.get_active_sessions_count(),
            'can_create_vpn_client': self.can_create_vpn_client(),
            'can_connect_vpn': self.can_connect_vpn(),
        }

        # NOTE(review): this checks the *serialized* user's own is_admin flag,
        # not the requester's. include_sensitive=True on a non-admin user adds
        # nothing. Confirm this is intended.
        if include_sensitive and (self.is_admin or include_sensitive == 'self'):
            data.update({
                'is_admin': self.is_admin,
                'failed_login_attempts': self.failed_login_attempts,
                'account_locked': self.is_account_locked(),
                'bandwidth_usage_today': self.get_bandwidth_usage_today(),
            })

        return data


class VPNClient(db.Model):
    """VPN Client configuration and management."""

    __tablename__ = 'vpn_clients'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_name = db.Column(db.String(100), nullable=False)
    protocol = db.Column(db.String(20), nullable=False)  # openvpn, ikev2, wireguard
    certificate_serial = db.Column(db.String(50), unique=True)
    private_key_path = db.Column(db.String(255))
    certificate_path = db.Column(db.String(255))
    config_file_path = db.Column(db.String(255))
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_connected = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    device_type = db.Column(db.String(50))  # windows, macos, linux, ios, android
    public_key = db.Column(db.Text)  # For WireGuard

    # Relationships
    sessions = db.relationship('VPNSession', backref='vpn_client', lazy=True)

    def __init__(self, user_id, client_name, protocol, device_type=None):
        """Create a client record owned by ``user_id`` for the given ``protocol``."""
        self.user_id = user_id
        self.client_name = client_name
        self.protocol = protocol
        self.device_type = device_type

    def get_active_sessions(self):
        """Return this client's sessions that have not been disconnected."""
        return [s for s in self.sessions if s.disconnected_at is None]

    def get_total_bandwidth_usage(self):
        """Return total bytes sent plus received across all of this client's sessions."""
        # Byte counters may still be None on sessions that are not yet flushed.
        return sum((s.bytes_received or 0) + (s.bytes_sent or 0) for s in self.sessions)

    def to_dict(self):
        """Serialize the client to a dict. Key and file paths are not included."""
        return {
            'id': self.id,
            'client_name': self.client_name,
            'protocol': self.protocol,
            'device_type': self.device_type,
            'created_at': _isoformat(self.created_at),
            'last_connected': _isoformat(self.last_connected),
            'is_active': self.is_active,
            'certificate_serial': self.certificate_serial,
            'active_sessions': len(self.get_active_sessions()),
            'total_bandwidth_usage': self.get_total_bandwidth_usage(),
        }


class VPNSession(db.Model):
    """VPN Session tracking."""

    __tablename__ = 'vpn_sessions'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_id = db.Column(db.Integer, db.ForeignKey('vpn_clients.id'), nullable=False)
    server_protocol = db.Column(db.String(20), nullable=False)
    # NOTE(review): client_ip/server_ip hold at most 15 chars (dotted-quad
    # IPv4). Widen them with a migration if tunnel addresses can be IPv6.
    client_ip = db.Column(db.String(15))
    server_ip = db.Column(db.String(15))
    client_real_ip = db.Column(db.String(45))  # Support IPv6
    connected_at = db.Column(db.DateTime, default=datetime.utcnow)
    disconnected_at = db.Column(db.DateTime)
    bytes_received = db.Column(db.BigInteger, default=0)
    bytes_sent = db.Column(db.BigInteger, default=0)
    session_duration = db.Column(db.Integer)  # in seconds
    disconnect_reason = db.Column(db.String(100))

    def __init__(self, user_id, client_id, server_protocol, client_ip=None,
                 server_ip=None, client_real_ip=None):
        """Open a session record for ``client_id`` belonging to ``user_id``."""
        self.user_id = user_id
        self.client_id = client_id
        self.server_protocol = server_protocol
        self.client_ip = client_ip
        self.server_ip = server_ip
        self.client_real_ip = client_real_ip

    def disconnect(self, reason=None):
        """Mark the session disconnected now and record ``reason``.

        If the connect time is known, the duration is stored in whole seconds.
        Calling this again overwrites the earlier disconnect time.
        """
        self.disconnected_at = datetime.utcnow()
        self.disconnect_reason = reason
        if self.connected_at:
            self.session_duration = int((self.disconnected_at - self.connected_at).total_seconds())

    def is_active(self):
        """Return True while the session has no ``disconnected_at``."""
        return self.disconnected_at is None

    def get_duration(self):
        """Return the session length in seconds.

        This is the stored ``session_duration`` once disconnected (``None`` if
        the connect time was unknown), the elapsed time so far while still
        connected, or 0 if there is no connect time.
        """
        if self.disconnected_at:
            return self.session_duration
        if self.connected_at:
            return int((datetime.utcnow() - self.connected_at).total_seconds())
        return 0

    def to_dict(self):
        """Serialize the session to a dict; ``session_duration`` comes from ``get_duration``."""
        return {
            'id': self.id,
            'client_id': self.client_id,
            'server_protocol': self.server_protocol,
            'client_ip': self.client_ip,
            'server_ip': self.server_ip,
            'client_real_ip': self.client_real_ip,
            'connected_at': _isoformat(self.connected_at),
            'disconnected_at': _isoformat(self.disconnected_at),
            'bytes_received': self.bytes_received,
            'bytes_sent': self.bytes_sent,
            'session_duration': self.get_duration(),
            'disconnect_reason': self.disconnect_reason,
            'is_active': self.is_active(),
        }


class ServerConfiguration(db.Model):
    """VPN Server configuration management."""

    __tablename__ = 'server_configurations'

    id = db.Column(db.Integer, primary_key=True)
    protocol = db.Column(db.String(20), nullable=False)
    server_name = db.Column(db.String(100), nullable=False)
    listen_port = db.Column(db.Integer, nullable=False)
    network_cidr = db.Column(db.String(18), nullable=False)
    dns_servers = db.Column(db.Text)  # JSON string
    routes = db.Column(db.Text)  # JSON string
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    max_clients = db.Column(db.Integer, default=100)

    def __init__(self, protocol, server_name, listen_port, network_cidr):
        """Create a server configuration listening on ``listen_port`` for ``network_cidr``."""
        self.protocol = protocol
        self.server_name = server_name
        self.listen_port = listen_port
        self.network_cidr = network_cidr

    def to_dict(self):
        """Serialize the configuration to a dict.

        ``dns_servers`` and ``routes`` are returned as stored (raw strings),
        not parsed.
        """
        return {
            'id': self.id,
            'protocol': self.protocol,
            'server_name': self.server_name,
            'listen_port': self.listen_port,
            'network_cidr': self.network_cidr,
            'dns_servers': self.dns_servers,
            'routes': self.routes,
            'is_active': self.is_active,
            'created_at': _isoformat(self.created_at),
            'updated_at': _isoformat(self.updated_at),
            'max_clients': self.max_clients,
        }