# Page residue from the scraped source: "Spaces: Paused"
"""
Enhanced User Model with Authentication and VPN Client Management

This module provides comprehensive user management with security features,
VPN client management, and session tracking capabilities.
"""
| from werkzeug.security import generate_password_hash, check_password_hash | |
| from datetime import datetime, timedelta | |
| import secrets | |
| import jwt | |
| import re | |
| from flask import current_app | |
| from .user import db | |
class User(db.Model):
    """Application user with credential handling, JWT helpers and VPN quotas.

    Passwords are stored as a werkzeug hash of ``password + salt``. After
    ``MAX_FAILED_LOGIN_ATTEMPTS`` consecutive failed password checks, the
    account is locked for ``LOCKOUT_DURATION``.
    """
    __tablename__ = 'users'

    # Consecutive failed password checks that trigger a temporary lockout.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long a locked account stays locked.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Lifetime of a password reset token.
    PASSWORD_RESET_TTL = timedelta(hours=1)
    # Maximum number of *active* VPN clients per subscription tier; tiers not
    # listed here may not create any client.
    VPN_CLIENT_LIMITS = {'free': 1, 'premium': 5, 'enterprise': 50}

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(255), nullable=False)
    salt = db.Column(db.String(32), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_login = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    is_admin = db.Column(db.Boolean, default=False)
    subscription_type = db.Column(db.String(20), default='free')
    subscription_expires = db.Column(db.DateTime)
    max_concurrent_connections = db.Column(db.Integer, default=1)
    bandwidth_limit_mbps = db.Column(db.Integer, default=10)
    email_verified = db.Column(db.Boolean, default=False)
    email_verification_token = db.Column(db.String(64))
    two_factor_enabled = db.Column(db.Boolean, default=False)
    two_factor_secret = db.Column(db.String(32))
    password_reset_token = db.Column(db.String(64))
    password_reset_expires = db.Column(db.DateTime)
    failed_login_attempts = db.Column(db.Integer, default=0)
    account_locked_until = db.Column(db.DateTime)

    # Relationships
    vpn_clients = db.relationship('VPNClient', backref='user', lazy=True, cascade='all, delete-orphan')
    # NOTE(review): no delete cascade here while VPNSession.user_id is NOT NULL;
    # deleting a user that still has sessions will likely fail - confirm intended.
    vpn_sessions = db.relationship('VPNSession', backref='user', lazy=True)

    def __init__(self, username, email, password=None):
        """Create a user and issue a fresh email verification token.

        Args:
            username: login name (validated separately by validate_username).
            email: contact address (validated separately by validate_email).
            password: optional plaintext password; hashed via set_password.
                Without it, ``password_hash``/``salt`` stay unset, and both
                columns are NOT NULL.

        Raises:
            ValueError: if *password* is given but fails the strength policy.
        """
        self.username = username
        self.email = email
        if password:
            self.set_password(password)
        self.email_verification_token = secrets.token_urlsafe(32)

    def set_password(self, password):
        """Hash and store *password* with a fresh per-user salt.

        Also clears the failed-login counter and any active lockout.

        Raises:
            ValueError: if the password fails validate_password_strength.
        """
        if not self.validate_password_strength(password):
            raise ValueError("Password does not meet security requirements")
        self.salt = secrets.token_hex(16)
        # werkzeug embeds its own salt in the hash; the extra per-user salt is
        # part of the stored format (check_password appends it as well).
        self.password_hash = generate_password_hash(password + self.salt)
        self.failed_login_attempts = 0
        self.account_locked_until = None

    def check_password(self, password):
        """Return True if *password* matches the stored hash.

        While the account is locked this returns False without counting the
        attempt.

        On success it resets the failed-attempt counter and stamps
        ``last_login``. On failure it increments the counter and locks the
        account once MAX_FAILED_LOGIN_ATTEMPTS is reached.
        """
        if self.is_account_locked():
            return False
        # A user created without a password has nothing to compare against.
        if not self.password_hash or not self.salt or not isinstance(password, str):
            return False
        is_valid = check_password_hash(self.password_hash, password + self.salt)
        if is_valid:
            self.failed_login_attempts = 0
            self.last_login = datetime.utcnow()
        else:
            # The column default (0) is only applied on INSERT, so the counter
            # may still be None on an unflushed instance.
            self.failed_login_attempts = (self.failed_login_attempts or 0) + 1
            if self.failed_login_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
                self.account_locked_until = datetime.utcnow() + self.LOCKOUT_DURATION
        return is_valid

    def is_account_locked(self):
        """Return True while a lockout is in effect.

        An expired lockout is cleared here: ``account_locked_until`` and the
        failed-attempt counter are reset, so calling this may modify the row.
        """
        if self.account_locked_until is None:
            return False
        if datetime.utcnow() < self.account_locked_until:
            return True
        self.account_locked_until = None
        self.failed_login_attempts = 0
        return False

    @staticmethod
    def validate_password_strength(password):
        """Return True if *password* satisfies the password policy.

        The policy requires at least 8 characters, including an uppercase
        letter, a lowercase letter, a digit and one of the characters
        !@#$%^&*(),.?":{}|<>. Non-string input is rejected.
        """
        if not isinstance(password, str) or len(password) < 8:
            return False
        required_patterns = (r'[A-Z]', r'[a-z]', r'\d', r'[!@#$%^&*(),.?":{}|<>]')
        return all(re.search(pattern, password) for pattern in required_patterns)

    @staticmethod
    def validate_email(email):
        """Return True if *email* has the shape ``local@domain.tld``.

        ``re.fullmatch`` is used because ``$`` alone also matches before a
        trailing newline, which would let ``"a@b.co\\n"`` through.
        """
        if not isinstance(email, str):
            return False
        pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        return re.fullmatch(pattern, email) is not None

    @staticmethod
    def validate_username(username):
        """Return True for 3-80 characters drawn from ASCII letters, digits, ``_`` and ``-``."""
        if not isinstance(username, str):
            return False
        if not 3 <= len(username) <= 80:
            return False
        return re.fullmatch(r'^[a-zA-Z0-9_-]+$', username) is not None

    def generate_auth_token(self, expires_in=3600):
        """Return an HS256 JWT access token signed with the app's ``SECRET_KEY``.

        Requires a Flask application context (reads ``current_app.config``).

        Args:
            expires_in: token lifetime in seconds (default 3600).
        """
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'username': self.username,
            'email': self.email,
            'subscription_type': self.subscription_type,
            'is_admin': self.is_admin,
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    def generate_refresh_token(self, expires_in=2592000):  # 30 days
        """Return an HS256 JWT refresh token (``type: 'refresh'``) for this user.

        Args:
            expires_in: token lifetime in seconds (default 30 days).
        """
        now = datetime.utcnow()
        payload = {
            'user_id': self.id,
            'type': 'refresh',
            'exp': now + timedelta(seconds=expires_in),
            'iat': now,
        }
        return jwt.encode(payload, current_app.config['SECRET_KEY'], algorithm='HS256')

    @staticmethod
    def _decode_token(token):
        """Decode an HS256 JWT signed with SECRET_KEY; return the payload, or None if invalid."""
        try:
            return jwt.decode(token, current_app.config['SECRET_KEY'], algorithms=['HS256'])
        except jwt.InvalidTokenError:  # also covers ExpiredSignatureError
            return None

    @staticmethod
    def verify_auth_token(token):
        """Return the User for a valid access token, or None.

        Refresh tokens are rejected here so they cannot be used for
        authentication.
        """
        payload = User._decode_token(token)
        if payload is None or payload.get('type') == 'refresh':
            return None
        user_id = payload.get('user_id')
        if user_id is None:
            return None
        return User.query.get(user_id)

    @staticmethod
    def verify_refresh_token(token):
        """Return the User for a valid refresh token, or None for any other token."""
        payload = User._decode_token(token)
        if payload is None or payload.get('type') != 'refresh':
            return None
        user_id = payload.get('user_id')
        if user_id is None:
            return None
        return User.query.get(user_id)

    @staticmethod
    def _tokens_match(expected, supplied):
        """Compare a stored token with a user-supplied one in constant time.

        Returns False when either side is missing, empty or not a string.
        """
        if not isinstance(expected, str) or not expected:
            return False
        if not isinstance(supplied, str) or not supplied:
            return False
        return secrets.compare_digest(expected.encode('utf-8'), supplied.encode('utf-8'))

    def generate_password_reset_token(self):
        """Issue, store and return a new reset token valid for PASSWORD_RESET_TTL."""
        self.password_reset_token = secrets.token_urlsafe(32)
        self.password_reset_expires = datetime.utcnow() + self.PASSWORD_RESET_TTL
        return self.password_reset_token

    def verify_password_reset_token(self, token):
        """Return True if *token* matches the stored reset token and has not expired."""
        if not self.password_reset_expires or datetime.utcnow() >= self.password_reset_expires:
            return False
        return self._tokens_match(self.password_reset_token, token)

    def reset_password(self, new_password, token):
        """Set *new_password* if *token* is a valid reset token, then invalidate the token.

        Returns False if the token is invalid.

        Raises:
            ValueError: if *new_password* fails the strength policy. The reset
                token is left intact in that case.
        """
        if not self.verify_password_reset_token(token):
            return False
        self.set_password(new_password)
        self.password_reset_token = None
        self.password_reset_expires = None
        return True

    def verify_email(self, token):
        """Mark the email verified if *token* matches the stored verification token.

        The token is single-use: it is cleared on success, so later calls
        (including with ``None``) return False.
        """
        if not self._tokens_match(self.email_verification_token, token):
            return False
        self.email_verified = True
        self.email_verification_token = None
        return True

    def _count_active_vpn_clients(self):
        """Number of this user's VPN clients whose ``is_active`` flag is truthy."""
        return sum(1 for client in self.vpn_clients if client.is_active)

    def can_create_vpn_client(self):
        """Return True if the user's tier allows one more active VPN client."""
        limit = self.VPN_CLIENT_LIMITS.get(self.subscription_type)
        if limit is None:
            return False
        return self._count_active_vpn_clients() < limit

    def get_active_sessions_count(self):
        """Count sessions that have no ``disconnected_at`` timestamp."""
        return sum(1 for session in self.vpn_sessions if session.disconnected_at is None)

    def can_connect_vpn(self):
        """Return True if open sessions are below ``max_concurrent_connections``."""
        return self.get_active_sessions_count() < self.max_concurrent_connections

    def get_bandwidth_usage_today(self):
        """Return bytes sent + received by sessions that connected today (UTC date).

        Only ``connected_at`` is checked. A session that started on an earlier
        day is not counted, even if it is still running.
        """
        today = datetime.utcnow().date()
        return sum(
            (session.bytes_received or 0) + (session.bytes_sent or 0)
            for session in self.vpn_sessions
            if session.connected_at and session.connected_at.date() == today
        )

    def is_subscription_active(self):
        """Return True for the free tier, or for a paid tier whose expiry is in the future."""
        if self.subscription_type == 'free':
            return True
        return self.subscription_expires is not None and datetime.utcnow() < self.subscription_expires

    def to_dict(self, include_sensitive=False):
        """Serialise the user into a JSON-friendly dict.

        Args:
            include_sensitive: when truthy, adds admin/lockout/bandwidth
                fields, subject to the condition below.

        NOTE(review): the condition checks the *serialized* user's ``is_admin``
        flag, not the caller's privileges. Sensitive fields therefore appear
        for every admin account whenever include_sensitive is truthy. For a
        non-admin account they appear only when include_sensitive == 'self'.
        Confirm this is intended.

        Computing ``account_locked`` may clear an expired lockout.
        """
        data = {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'last_login': self.last_login.isoformat() if self.last_login else None,
            'is_active': self.is_active,
            'subscription_type': self.subscription_type,
            'subscription_expires': self.subscription_expires.isoformat() if self.subscription_expires else None,
            'max_concurrent_connections': self.max_concurrent_connections,
            'bandwidth_limit_mbps': self.bandwidth_limit_mbps,
            'email_verified': self.email_verified,
            'two_factor_enabled': self.two_factor_enabled,
            'is_subscription_active': self.is_subscription_active(),
            'active_vpn_clients': self._count_active_vpn_clients(),
            'active_sessions': self.get_active_sessions_count(),
            'can_create_vpn_client': self.can_create_vpn_client(),
            'can_connect_vpn': self.can_connect_vpn(),
        }
        if include_sensitive and (self.is_admin or include_sensitive == 'self'):
            data.update({
                'is_admin': self.is_admin,
                'failed_login_attempts': self.failed_login_attempts,
                'account_locked': self.is_account_locked(),
                'bandwidth_usage_today': self.get_bandwidth_usage_today(),
            })
        return data
class VPNClient(db.Model):
    """A VPN client profile owned by a user, with its protocol and credential locations."""
    __tablename__ = 'vpn_clients'
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_name = db.Column(db.String(100), nullable=False)
    protocol = db.Column(db.String(20), nullable=False)  # openvpn, ikev2, wireguard
    certificate_serial = db.Column(db.String(50), unique=True)
    private_key_path = db.Column(db.String(255))
    certificate_path = db.Column(db.String(255))
    config_file_path = db.Column(db.String(255))
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    last_connected = db.Column(db.DateTime)
    is_active = db.Column(db.Boolean, default=True)
    device_type = db.Column(db.String(50))  # windows, macos, linux, ios, android
    public_key = db.Column(db.Text)  # For WireGuard

    # Relationships
    sessions = db.relationship('VPNSession', backref='vpn_client', lazy=True)

    def __init__(self, user_id, client_name, protocol, device_type=None):
        """Create a client for *user_id*; the remaining columns use their defaults or stay NULL."""
        self.user_id = user_id
        self.client_name = client_name
        self.protocol = protocol
        self.device_type = device_type

    def get_active_sessions(self):
        """Return this client's sessions that have no ``disconnected_at`` timestamp."""
        return [s for s in self.sessions if s.disconnected_at is None]

    def get_total_bandwidth_usage(self):
        """Return total bytes sent + received across all of this client's sessions.

        Byte counters of a session not yet flushed may still be None (the
        column default is applied on INSERT), so they count as 0.
        """
        return sum((s.bytes_received or 0) + (s.bytes_sent or 0) for s in self.sessions)

    def to_dict(self):
        """Serialise the client into a JSON-friendly dict.

        Key/certificate/config file paths and ``public_key`` are not included.
        """
        return {
            'id': self.id,
            'client_name': self.client_name,
            'protocol': self.protocol,
            'device_type': self.device_type,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'last_connected': self.last_connected.isoformat() if self.last_connected else None,
            'is_active': self.is_active,
            'certificate_serial': self.certificate_serial,
            'active_sessions': len(self.get_active_sessions()),
            'total_bandwidth_usage': self.get_total_bandwidth_usage(),
        }
class VPNSession(db.Model):
    """One VPN connection of a client, from connect to disconnect, with traffic counters."""
    __tablename__ = 'vpn_sessions'
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    client_id = db.Column(db.Integer, db.ForeignKey('vpn_clients.id'), nullable=False)
    server_protocol = db.Column(db.String(20), nullable=False)
    # NOTE(review): 15 chars only fits dotted IPv4, unlike client_real_ip (45) - confirm.
    client_ip = db.Column(db.String(15))
    server_ip = db.Column(db.String(15))
    client_real_ip = db.Column(db.String(45))  # Support IPv6
    connected_at = db.Column(db.DateTime, default=datetime.utcnow)
    disconnected_at = db.Column(db.DateTime)
    bytes_received = db.Column(db.BigInteger, default=0)
    bytes_sent = db.Column(db.BigInteger, default=0)
    session_duration = db.Column(db.Integer)  # in seconds
    disconnect_reason = db.Column(db.String(100))

    def __init__(self, user_id, client_id, server_protocol, client_ip=None, server_ip=None, client_real_ip=None):
        """Open a session record; ``connected_at`` defaults to utcnow on INSERT."""
        self.user_id = user_id
        self.client_id = client_id
        self.server_protocol = server_protocol
        self.client_ip = client_ip
        self.server_ip = server_ip
        self.client_real_ip = client_real_ip

    def disconnect(self, reason=None):
        """Close the session now, recording *reason* and the duration in seconds.

        Calling this again overwrites the timestamp, reason and duration.
        """
        self.disconnected_at = datetime.utcnow()
        self.disconnect_reason = reason
        if self.connected_at:
            self.session_duration = int((self.disconnected_at - self.connected_at).total_seconds())

    def is_active(self):
        """Return True while the session has no ``disconnected_at`` timestamp."""
        return self.disconnected_at is None

    def get_duration(self):
        """Return the session length in whole seconds.

        Closed sessions report the stored ``session_duration``. If that value
        was never recorded, it is derived from the two timestamps. Open
        sessions are measured against the current UTC time. Returns 0 when
        there is not enough information.
        """
        if self.disconnected_at:
            if self.session_duration is not None:
                return self.session_duration
            if self.connected_at:
                return int((self.disconnected_at - self.connected_at).total_seconds())
            return 0
        if self.connected_at:
            return int((datetime.utcnow() - self.connected_at).total_seconds())
        return 0

    def to_dict(self):
        """Serialise the session into a JSON-friendly dict (duration via get_duration)."""
        return {
            'id': self.id,
            'client_id': self.client_id,
            'server_protocol': self.server_protocol,
            'client_ip': self.client_ip,
            'server_ip': self.server_ip,
            'client_real_ip': self.client_real_ip,
            'connected_at': self.connected_at.isoformat() if self.connected_at else None,
            'disconnected_at': self.disconnected_at.isoformat() if self.disconnected_at else None,
            'bytes_received': self.bytes_received,
            'bytes_sent': self.bytes_sent,
            'session_duration': self.get_duration(),
            'disconnect_reason': self.disconnect_reason,
            'is_active': self.is_active(),
        }
class ServerConfiguration(db.Model):
    """Persisted settings for one VPN server endpoint."""
    __tablename__ = 'server_configurations'
    id = db.Column(db.Integer, primary_key=True)
    protocol = db.Column(db.String(20), nullable=False)
    server_name = db.Column(db.String(100), nullable=False)
    listen_port = db.Column(db.Integer, nullable=False)
    network_cidr = db.Column(db.String(18), nullable=False)
    dns_servers = db.Column(db.Text)  # JSON string
    routes = db.Column(db.Text)  # JSON string
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    max_clients = db.Column(db.Integer, default=100)

    def __init__(self, protocol, server_name, listen_port, network_cidr):
        """Capture the four NOT NULL settings that have no column default.

        Every other column uses its column default or stays NULL on insert.
        """
        self.protocol = protocol
        self.server_name = server_name
        self.listen_port = listen_port
        self.network_cidr = network_cidr

    def to_dict(self):
        """Serialise this configuration into a plain dict.

        ``dns_servers`` and ``routes`` are returned exactly as stored; they are
        not decoded here. Timestamps become ISO-8601 strings, or None when unset.
        """
        def as_iso(moment):
            return moment.isoformat() if moment else None

        return dict(
            id=self.id,
            protocol=self.protocol,
            server_name=self.server_name,
            listen_port=self.listen_port,
            network_cidr=self.network_cidr,
            dns_servers=self.dns_servers,
            routes=self.routes,
            is_active=self.is_active,
            created_at=as_iso(self.created_at),
            updated_at=as_iso(self.updated_at),
            max_clients=self.max_clients,
        )