| """ |
| Enhanced User Model with Authentication and VPN Client Management |
| |
| This module provides comprehensive user management with security features, |
| VPN client management, and session tracking capabilities. |
| """ |
|
|
|
|
| from werkzeug.security import generate_password_hash, check_password_hash |
| from datetime import datetime, timedelta |
| import secrets |
| import jwt |
| import re |
| from flask import current_app |
|
|
| from .user import db |
|
|
class User(db.Model):
    """User account with authentication, token handling and VPN quota checks.

    Besides the mapped columns, the model provides password hashing with a
    temporary lockout after repeated failures, JWT access/refresh tokens,
    password-reset and email-verification tokens, and per-subscription
    limits on VPN clients and concurrent sessions.
    """
    __tablename__ = 'users'

    # Consecutive failed password checks that trigger a temporary lockout.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long an account stays locked once the limit above is reached.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Lifetime of a password-reset token.
    PASSWORD_RESET_TTL = timedelta(hours=1)
    # Maximum number of *active* VPN clients per subscription tier. Tiers not
    # listed here may not create any clients.
    SUBSCRIPTION_CLIENT_LIMITS = {'free': 1, 'premium': 5, 'enterprise': 50}

    # These patterns are applied with fullmatch(). Unlike re.match() with
    # ^...$, fullmatch() rejects a trailing newline, because "$" also matches
    # just before a final "\n".
    _EMAIL_RE = re.compile(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}')
    _USERNAME_RE = re.compile(r'[a-zA-Z0-9_-]+')

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    salt = db.Column(db.String(32), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_login = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    is_admin = db.Column(db.Boolean, default=False)
    subscription_type = db.Column(db.String(20), default='free')
    subscription_expires = db.Column(db.DateTime)
    max_concurrent_connections = db.Column(db.Integer, default=1)
    bandwidth_limit_mbps = db.Column(db.Integer, default=10)
    email_verified = db.Column(db.Boolean, default=False)
    email_verification_token = db.Column(db.String(64))
    two_factor_enabled = db.Column(db.Boolean, default=False)
    two_factor_secret = db.Column(db.String(32))
    password_reset_token = db.Column(db.String(64))
    password_reset_expires = db.Column(db.DateTime)
    failed_login_attempts = db.Column(db.Integer, default=0)
    account_locked_until = db.Column(db.DateTime)

    vpn_clients = db.relationship('VPNClient', backref='user', lazy=True, cascade='all, delete-orphan')
    # NOTE(review): this relationship has no delete cascade, and
    # VPNSession.user_id is NOT NULL. Confirm how sessions are handled when a
    # user (or, through the cascade above, one of their clients) is deleted.
    vpn_sessions = db.relationship('VPNSession', backref='user', lazy=True)

    def __init__(self, username, email, password=None):
        """Create a user and issue a fresh email-verification token.

        If *password* is given it is hashed via set_password().

        Raises:
            ValueError: if *password* is given but fails the strength rules.
        """
        self.username = username
        self.email = email
        if password:
            self.set_password(password)
        self.email_verification_token = secrets.token_urlsafe(32)

    def set_password(self, password):
        """Hash and store *password*, clearing any failed-login lockout.

        Raises:
            ValueError: if the password fails validate_password_strength().
        """
        if not self.validate_password_strength(password):
            raise ValueError("Password does not meet security requirements")

        # generate_password_hash() already applies its own random salt. This
        # extra per-user salt is redundant, but it must stay so that hashes
        # already stored keep verifying in check_password().
        self.salt = secrets.token_hex(16)
        self.password_hash = generate_password_hash(password + self.salt)
        self.failed_login_attempts = 0
        self.account_locked_until = None

    def check_password(self, password):
        """Return True if *password* matches the stored hash.

        While the account is locked this always returns False. A failed check
        increments the failed-attempt counter and, once
        MAX_FAILED_LOGIN_ATTEMPTS is reached, locks the account for
        LOCKOUT_DURATION. A successful check resets the counter and updates
        last_login.
        """
        if self.is_account_locked():
            return False
        # Nothing can match when no password has been set on this instance or
        # the candidate is not a string.
        if not self.password_hash or self.salt is None or not isinstance(password, str):
            return False

        is_valid = check_password_hash(self.password_hash, password + self.salt)
        now = datetime.utcnow()

        if is_valid:
            self.failed_login_attempts = 0
            self.last_login = now
        else:
            # The column is nullable and its default is only applied on
            # INSERT, so a missing counter is treated as zero.
            self.failed_login_attempts = (self.failed_login_attempts or 0) + 1
            if self.failed_login_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
                self.account_locked_until = now + self.LOCKOUT_DURATION

        return is_valid

    def is_account_locked(self):
        """Return True while the failed-login lockout is in effect.

        Side effect: once the lockout has expired, the lock timestamp and the
        failed-attempt counter are cleared on this instance.
        """
        if self.account_locked_until is None:
            return False
        if datetime.utcnow() < self.account_locked_until:
            return True
        # The lockout has expired: reset it so the user gets a fresh set of attempts.
        self.account_locked_until = None
        self.failed_login_attempts = 0
        return False

    @staticmethod
    def validate_password_strength(password):
        """Return True if *password* satisfies the password policy.

        The policy requires a str of at least 8 characters containing an
        uppercase letter, a lowercase letter, a digit and one special
        character from the class in the last pattern below. Non-str input
        returns False.
        """
        if not isinstance(password, str) or len(password) < 8:
            return False
        # NOTE(review): characters such as '-', '_', '[' and ']' are not
        # counted as "special" by the pattern below.
        required = (r'[A-Z]', r'[a-z]', r'\d', r'[!@#$%^&*(),.?":{}|<>]')
        return all(re.search(pattern, password) for pattern in required)

    @staticmethod
    def validate_email(email):
        """Return True if *email* is a str shaped like ``local@domain.tld``."""
        if not isinstance(email, str):
            return False
        return User._EMAIL_RE.fullmatch(email) is not None

    @staticmethod
    def validate_username(username):
        """Return True if *username* has 3-80 characters from ``[a-zA-Z0-9_-]``."""
        if not isinstance(username, str):
            return False
        if len(username) < 3 or len(username) > 80:
            return False
        return User._USERNAME_RE.fullmatch(username) is not None

    def generate_auth_token(self, expires_in=3600):
        """Return an HS256 JWT access token valid for *expires_in* seconds.

        The payload carries the user's id, username, email, subscription type
        and admin flag, and is signed with the Flask app's SECRET_KEY.
        """
        # Use a single timestamp so that 'iat' and 'exp' are consistent.
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'username': self.username,
            'email': self.email,
            'subscription_type': self.subscription_type,
            'is_admin': self.is_admin,
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    def generate_refresh_token(self, expires_in=2592000):
        """Return an HS256 JWT refresh token valid for *expires_in* seconds (default 30 days).

        The payload is tagged ``'type': 'refresh'``, so verify_auth_token()
        rejects it and only verify_refresh_token() accepts it.
        """
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'type': 'refresh',
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    @staticmethod
    def _decode_token(token):
        """Decode an HS256 JWT signed with SECRET_KEY.

        Returns the payload, or None if the token is expired, malformed or
        carries a bad signature.
        """
        try:
            return jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
        except jwt.InvalidTokenError:  # also covers jwt.ExpiredSignatureError
            return None

    @staticmethod
    def verify_auth_token(token):
        """Return the User for a valid access token, or None.

        Refresh tokens, invalid or expired tokens, and payloads without a
        user_id are all rejected.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') == 'refresh':
            return None
        user_id = payload.get('user_id')
        if user_id is None:
            return None
        return User.query.get(user_id)

    @staticmethod
    def verify_refresh_token(token):
        """Return the User for a valid refresh token, or None.

        Access tokens, invalid or expired tokens, and payloads without a
        user_id are all rejected.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') != 'refresh':
            return None
        user_id = payload.get('user_id')
        if user_id is None:
            return None
        return User.query.get(user_id)

    @staticmethod
    def _tokens_match(expected, supplied):
        """Compare two token strings in constant time.

        Returns False if either value is missing, empty or not a str, so that
        an unset stored token can never be matched by an empty or None input.
        """
        if not isinstance(expected, str) or not isinstance(supplied, str):
            return False
        if not expected or not supplied:
            return False
        # Compare bytes: compare_digest() rejects non-ASCII str arguments.
        return secrets.compare_digest(expected.encode('utf-8'), supplied.encode('utf-8'))

    def generate_password_reset_token(self):
        """Create, store and return a reset token valid for PASSWORD_RESET_TTL."""
        self.password_reset_token = secrets.token_urlsafe(32)
        self.password_reset_expires = datetime.utcnow() + self.PASSWORD_RESET_TTL
        return self.password_reset_token

    def verify_password_reset_token(self, token):
        """Return True if *token* matches the stored reset token and it has not expired."""
        if self.password_reset_expires is None:
            return False
        if datetime.utcnow() >= self.password_reset_expires:
            return False
        return self._tokens_match(self.password_reset_token, token)

    def reset_password(self, new_password, token):
        """Set *new_password* if *token* is a valid reset token.

        On success the reset token is consumed and True is returned. An
        invalid token returns False.

        Raises:
            ValueError: if *new_password* fails the strength rules. The reset
                token then stays valid.
        """
        if not self.verify_password_reset_token(token):
            return False

        self.set_password(new_password)
        self.password_reset_token = None
        self.password_reset_expires = None
        return True

    def verify_email(self, token):
        """Mark the email as verified if *token* matches the stored verification token.

        The token is consumed on success. Returns False when no token is
        stored or the token does not match.
        """
        if self._tokens_match(self.email_verification_token, token):
            self.email_verified = True
            self.email_verification_token = None
            return True
        return False

    def _count_active_clients(self):
        """Return the number of this user's VPN clients with is_active set."""
        return sum(1 for client in self.vpn_clients if client.is_active)

    def can_create_vpn_client(self):
        """Return True if the subscription tier allows another active VPN client.

        Limits come from SUBSCRIPTION_CLIENT_LIMITS. Unknown tiers return
        False. Subscription expiry is not checked here.
        """
        limit = self.SUBSCRIPTION_CLIENT_LIMITS.get(self.subscription_type)
        if limit is None:
            return False
        return self._count_active_clients() < limit

    def get_active_sessions_count(self):
        """Return the number of this user's sessions with no disconnected_at."""
        return sum(1 for session in self.vpn_sessions if session.disconnected_at is None)

    def can_connect_vpn(self):
        """Return True if active sessions are below max_concurrent_connections."""
        return self.get_active_sessions_count() < self.max_concurrent_connections

    def get_bandwidth_usage_today(self):
        """Return bytes sent plus received for sessions started on the current UTC date.

        A session is counted in full on the day it connected, even if it runs
        past midnight. Missing byte counters (None before the row is flushed)
        count as 0.
        """
        today = datetime.utcnow().date()
        return sum(
            (session.bytes_received or 0) + (session.bytes_sent or 0)
            for session in self.vpn_sessions
            if session.connected_at and session.connected_at.date() == today
        )

    def is_subscription_active(self):
        """Return True if the subscription is currently active.

        Free accounts are always active. Other tiers are active only while
        subscription_expires is set and still in the future. The result is
        always a bool.
        """
        if self.subscription_type == 'free':
            return True
        return self.subscription_expires is not None and datetime.utcnow() < self.subscription_expires

    def to_dict(self, include_sensitive=False):
        """Serialize the user to a JSON-friendly dict.

        Args:
            include_sensitive: when truthy, also add is_admin,
                failed_login_attempts, account_locked and
                bandwidth_usage_today. This only happens if this user is an
                admin or the value is the string 'self'.

        Computing 'account_locked' calls is_account_locked(), which clears an
        expired lockout on this instance.
        """
        data = {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'last_login': self.last_login.isoformat() if self.last_login else None,
            'is_active': self.is_active,
            'subscription_type': self.subscription_type,
            'subscription_expires': self.subscription_expires.isoformat() if self.subscription_expires else None,
            'max_concurrent_connections': self.max_concurrent_connections,
            'bandwidth_limit_mbps': self.bandwidth_limit_mbps,
            'email_verified': self.email_verified,
            'two_factor_enabled': self.two_factor_enabled,
            'is_subscription_active': self.is_subscription_active(),
            'active_vpn_clients': self._count_active_clients(),
            'active_sessions': self.get_active_sessions_count(),
            'can_create_vpn_client': self.can_create_vpn_client(),
            'can_connect_vpn': self.can_connect_vpn(),
        }

        # NOTE(review): this gate checks *this* user's is_admin flag, not the
        # requester's, so include_sensitive=True has no effect for non-admin
        # accounts. Confirm this is intended.
        if include_sensitive and (self.is_admin or include_sensitive == 'self'):
            data.update({
                'is_admin': self.is_admin,
                'failed_login_attempts': self.failed_login_attempts,
                'account_locked': self.is_account_locked(),
                'bandwidth_usage_today': self.get_bandwidth_usage_today(),
            })

        return data
|
|
|
|
class VPNClient(db.Model):
    """A VPN client (device) registered by a user, with its session history.

    The owning User is reachable through the ``user`` back-reference defined
    on User.vpn_clients.
    """
    __tablename__ = 'vpn_clients'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_name = db.Column(db.String(100), nullable=False)
    protocol = db.Column(db.String(20), nullable=False)
    certificate_serial = db.Column(db.String(50), unique=True)
    private_key_path = db.Column(db.String(255))
    certificate_path = db.Column(db.String(255))
    config_file_path = db.Column(db.String(255))
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_connected = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    device_type = db.Column(db.String(50))
    public_key = db.Column(db.Text)

    # Sessions opened through this client. Each VPNSession gets a
    # ``vpn_client`` back-reference.
    # NOTE(review): there is no delete cascade here, and
    # VPNSession.client_id is NOT NULL. Confirm that deleting a client which
    # has sessions is handled.
    sessions = db.relationship('VPNSession', backref='vpn_client', lazy=True)

    def __init__(self, user_id, client_name, protocol, device_type=None):
        """Register a client for *user_id*.

        The remaining columns are left unset or filled from their column
        defaults on insert.
        """
        self.user_id = user_id
        self.client_name = client_name
        self.protocol = protocol
        self.device_type = device_type

    def get_active_sessions(self):
        """Return this client's sessions that have no disconnected_at yet."""
        return [s for s in self.sessions if s.disconnected_at is None]

    def get_total_bandwidth_usage(self):
        """Return bytes sent plus received across all of this client's sessions.

        Byte counters that are None count as 0. Their column default of 0 is
        only applied on INSERT, so unflushed sessions still hold None.
        """
        return sum((s.bytes_received or 0) + (s.bytes_sent or 0) for s in self.sessions)

    def to_dict(self):
        """Serialize the client to a JSON-friendly dict.

        Key and certificate file paths and public_key are not included.
        """
        return {
            'id': self.id,
            'client_name': self.client_name,
            'protocol': self.protocol,
            'device_type': self.device_type,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'last_connected': self.last_connected.isoformat() if self.last_connected else None,
            'is_active': self.is_active,
            'certificate_serial': self.certificate_serial,
            'active_sessions': len(self.get_active_sessions()),
            'total_bandwidth_usage': self.get_total_bandwidth_usage(),
        }
|
|
|
|
class VPNSession(db.Model):
    """One VPN connection made by a user through one of their clients.

    The session is active while disconnected_at is None. The ``user`` and
    ``vpn_client`` back-references are defined on User.vpn_sessions and
    VPNClient.sessions.
    """
    __tablename__ = 'vpn_sessions'

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_id = db.Column(db.Integer, db.ForeignKey('vpn_clients.id'), nullable=False)
    server_protocol = db.Column(db.String(20), nullable=False)
    # NOTE(review): String(15) only fits dotted-quad IPv4 addresses, whereas
    # client_real_ip allows 45 characters (IPv6). Confirm tunnel addresses are
    # always IPv4.
    client_ip = db.Column(db.String(15))
    server_ip = db.Column(db.String(15))
    client_real_ip = db.Column(db.String(45))
    connected_at = db.Column(db.DateTime, default=datetime.utcnow)
    disconnected_at = db.Column(db.DateTime)
    bytes_received = db.Column(db.BigInteger, default=0)
    bytes_sent = db.Column(db.BigInteger, default=0)
    # Whole seconds between connected_at and disconnected_at; set by disconnect().
    session_duration = db.Column(db.Integer)
    disconnect_reason = db.Column(db.String(100))

    def __init__(self, user_id, client_id, server_protocol, client_ip=None, server_ip=None, client_real_ip=None):
        """Open a session record.

        connected_at and the byte counters are filled from their column
        defaults on insert.
        """
        self.user_id = user_id
        self.client_id = client_id
        self.server_protocol = server_protocol
        self.client_ip = client_ip
        self.server_ip = server_ip
        self.client_real_ip = client_real_ip

    def disconnect(self, reason=None):
        """Mark the session as disconnected now and record *reason* and the duration.

        Calling this on a session that is already disconnected is a no-op, so
        the original disconnect time, reason and duration are preserved.
        """
        if self.disconnected_at is not None:
            return
        self.disconnected_at = datetime.utcnow()
        self.disconnect_reason = reason
        if self.connected_at:
            self.session_duration = int((self.disconnected_at - self.connected_at).total_seconds())

    def is_active(self):
        """Return True while the session has no disconnected_at."""
        return self.disconnected_at is None

    def get_duration(self):
        """Return the session length in whole seconds.

        - Finished sessions: the recorded session_duration. If it was never
          stored, it is derived from the timestamps when connected_at is known.
        - Running sessions: the time elapsed since connected_at.
        - Sessions without timestamps: 0.
        """
        if self.disconnected_at:
            if self.session_duration is None and self.connected_at:
                return int((self.disconnected_at - self.connected_at).total_seconds())
            return self.session_duration
        elif self.connected_at:
            return int((datetime.utcnow() - self.connected_at).total_seconds())
        return 0

    def to_dict(self):
        """Serialize the session to a JSON-friendly dict (datetimes as ISO 8601 strings)."""
        return {
            'id': self.id,
            'client_id': self.client_id,
            'server_protocol': self.server_protocol,
            'client_ip': self.client_ip,
            'server_ip': self.server_ip,
            'client_real_ip': self.client_real_ip,
            'connected_at': self.connected_at.isoformat() if self.connected_at else None,
            'disconnected_at': self.disconnected_at.isoformat() if self.disconnected_at else None,
            'bytes_received': self.bytes_received,
            'bytes_sent': self.bytes_sent,
            'session_duration': self.get_duration(),
            'disconnect_reason': self.disconnect_reason,
            'is_active': self.is_active(),
        }
|
|
|
|
class ServerConfiguration(db.Model):
    """Settings for a single VPN server endpoint.

    Each row stores one server's protocol, listen port and client address
    pool (``network_cidr``), plus optional DNS-server and route settings kept
    as free-form text.
    """
    __tablename__ = 'server_configurations'

    id = db.Column(db.Integer, primary_key=True)
    protocol = db.Column(db.String(20), nullable=False)
    server_name = db.Column(db.String(100), nullable=False)
    listen_port = db.Column(db.Integer, nullable=False)
    # NOTE(review): 18 characters fits an IPv4 CIDR such as
    # "255.255.255.255/32" but not an IPv6 prefix. Confirm that pools are IPv4 only.
    network_cidr = db.Column(db.String(18), nullable=False)
    # Stored as opaque text. The serialization format is not visible
    # here; TODO confirm with the code that writes these columns.
    dns_servers = db.Column(db.Text)
    routes = db.Column(db.Text)
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    max_clients = db.Column(db.Integer, default=100)

    def __init__(self, protocol, server_name, listen_port, network_cidr):
        """Build a configuration from its four required fields.

        The remaining columns are filled from their column defaults when the
        row is inserted.
        """
        self.protocol, self.server_name = protocol, server_name
        self.listen_port, self.network_cidr = listen_port, network_cidr

    def to_dict(self):
        """Return a JSON-friendly dict of this configuration.

        Timestamps are rendered as ISO 8601 strings, or None when unset.
        """
        def _iso(moment):
            return moment.isoformat() if moment else None

        plain_fields = (
            'id', 'protocol', 'server_name', 'listen_port',
            'network_cidr', 'dns_servers', 'routes', 'is_active',
        )
        result = {field: getattr(self, field) for field in plain_fields}
        result['created_at'] = _iso(self.created_at)
        result['updated_at'] = _iso(self.updated_at)
        result['max_clients'] = self.max_clients
        return result
|
|
|
|