Spaces:
Sleeping
Sleeping
| """ | |
| Security module for AIBOM generator implementation. | |
| This module provides security functions that can be integrated | |
| into the AIBOM generator to improve input validation, error handling, | |
| and protection against common web vulnerabilities. | |
| """ | |
| import re | |
| import os | |
| import json | |
| import logging | |
| from typing import Dict, Any, Optional, Union | |
# Logging configuration: configure the root logger at INFO level with a
# timestamp / logger-name / severity prefix, then create this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def validate_model_id(model_id: str) -> str:
    """
    Validate model ID to prevent injection attacks.

    Only ASCII letters, digits, hyphens, underscores and forward slashes are
    accepted, and the entire string must consist of those characters.

    Args:
        model_id: The model ID to validate

    Returns:
        The validated model ID (returned unchanged)

    Raises:
        ValueError: If the model ID is empty, is not a string, or contains
            invalid characters
    """
    if not model_id or not isinstance(model_id, str):
        raise ValueError("Model ID must be a non-empty string")

    # Use fullmatch rather than match with a trailing '$': '$' also matches
    # just before a final newline, which would let "org/model\n" through.
    if not re.fullmatch(r'[a-zA-Z0-9_\-/]+', model_id):
        raise ValueError(f"Invalid model ID format: {model_id}")

    # Defense in depth: '.' is not in the allowed character set above, so this
    # cannot trigger today; it protects against the set being widened later.
    if '..' in model_id:
        raise ValueError(f"Invalid model ID - contains path traversal sequence: {model_id}")

    return model_id
def safe_path_join(directory: str, filename: str) -> str:
    """
    Safely join directory and filename to prevent path traversal attacks.

    Any directory components in ``filename`` are discarded; only its final
    path component is appended to ``directory``.

    Args:
        directory: Base directory
        filename: Filename to append

    Returns:
        Safe file path

    Raises:
        ValueError: If the final component of ``filename`` is ``.`` or ``..``
    """
    # Keep only the last path component so "a/../../x" cannot climb out.
    name = os.path.basename(filename)

    # basename("..") is still "..", which would resolve to the parent of
    # `directory` once joined, so the special entries must be rejected.
    if name in (os.curdir, os.pardir):
        raise ValueError(f"Invalid filename: {filename}")

    return os.path.join(directory, name)
def safe_json_parse(json_string: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Safely parse JSON with error handling.

    The parsed value is returned as-is; a document whose top level is not a
    JSON object (for example a list) is therefore returned unchanged.

    Args:
        json_string: JSON string to parse
        default: Default value to return if parsing fails (an empty dict
            when omitted)

    Returns:
        Parsed JSON object or default value
    """
    if default is None:
        default = {}
    try:
        return json.loads(json_string)
    except (ValueError, TypeError, RecursionError) as e:
        # ValueError covers json.JSONDecodeError and also the
        # UnicodeDecodeError raised for bytes that are not valid UTF-8;
        # RecursionError is raised for excessively nested documents.
        logger.error(f"Invalid JSON: {e}")
        return default
# Translation table mapping each HTML-significant character to its entity.
_HTML_ESCAPES = str.maketrans({
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#x27;',
    '/': '&#x2F;',
})


def sanitize_html_output(text: str) -> str:
    """
    Sanitize text for safe HTML output to prevent XSS attacks.

    Replaces ``&``, ``<``, ``>``, ``"``, ``'`` and ``/`` with their HTML
    entities.

    Args:
        text: Text to sanitize

    Returns:
        Sanitized text; an empty string when ``text`` is empty or not a string
    """
    if not text or not isinstance(text, str):
        return ""
    # A single translate pass escapes every character exactly once, so the
    # '&' that starts a freshly inserted entity is never escaped again.
    return text.translate(_HTML_ESCAPES)
def secure_file_operations(file_path: str, operation: str, content: Optional[str] = None) -> Union[str, bool]:
    """
    Perform secure file operations with proper error handling.

    Files are always opened as UTF-8 text. Errors are logged rather than
    raised.

    Args:
        file_path: Path to the file
        operation: Operation to perform ('read', 'write', 'append')
        content: Content to write (for 'write' and 'append' operations)

    Returns:
        File content for 'read' operation, True for successful 'write'/'append',
        False otherwise. A failed 'read' returns an empty string, which is
        indistinguishable from reading an empty file.
    """
    # open() mode for each operation that writes content.
    write_modes = {'write': 'w', 'append': 'a'}
    try:
        if operation == 'read':
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()

        mode = write_modes.get(operation)
        if mode is None:
            logger.error(f"Invalid file operation: {operation}")
            return False
        if content is None:
            # Report the real problem instead of blaming a valid operation name.
            logger.error(f"No content provided for file operation: {operation}")
            return False

        with open(file_path, mode, encoding='utf-8') as f:
            f.write(content)
        return True
    except Exception as e:
        # Deliberately broad: this helper is best-effort and must not raise.
        logger.error(f"File operation failed: {e}")
        return "" if operation == 'read' else False
# Compiled once at import time; validate_url() applies it with fullmatch().
_URL_PATTERN = re.compile(
    r'(https?):\/\/'  # http:// or https://
    r'(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*'  # domain segments
    r'([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])'  # last domain segment
    r'(:\d+)?'  # optional port
    r'(\/[-a-zA-Z0-9%_.~#+]*)*'  # path
    r'(\?[;&a-zA-Z0-9%_.~+=-]*)?'  # query string
    r'(\#[-a-zA-Z0-9%_.~+=/]*)?'  # fragment
)


def validate_url(url: str) -> bool:
    """
    Validate URL format to prevent malicious URL injection.

    Only ``http`` and ``https`` URLs whose host is made of alphanumeric /
    hyphen labels are accepted; the whole string must match the pattern.

    Args:
        url: URL to validate

    Returns:
        True if URL is valid, False otherwise (including non-string input)
    """
    if not isinstance(url, str):
        return False
    # fullmatch rejects a trailing newline, which match() with '$' would
    # accept and which could otherwise be used for CR/LF injection.
    return _URL_PATTERN.fullmatch(url) is not None
# Matches a "{{ key }}" placeholder for the basic fallback renderer.
_PLACEHOLDER_RE = re.compile(r"\{\{(.*?)\}\}")


def _render_basic(template_content: str, context: Dict[str, Any]) -> str:
    """Substitute ``{{key}}`` placeholders with HTML-escaped context values."""

    def substitute(match):
        raw = match.group(1)
        # Accept both "{{key}}" and "{{ key }}".
        for key in (raw, raw.strip()):
            if key in context:
                return sanitize_html_output(str(context[key]))
        # Unknown placeholders are left untouched.
        return match.group(0)

    # One pass over the template: text inserted for one placeholder is never
    # rescanned, so a value containing "{{other}}" cannot pull in another
    # context variable.
    return _PLACEHOLDER_RE.sub(substitute, template_content)


def secure_template_rendering(template_content: str, context: Dict[str, Any]) -> str:
    """
    Render templates securely with auto-escaping enabled.

    Uses Jinja2 with auto-escaping when it is installed. Otherwise falls back
    to a very basic ``{{key}}`` substitution with HTML-escaped values (not
    recommended for production).

    Args:
        template_content: Template content
        context: Context variables for rendering

    Returns:
        Rendered template, or an empty string if rendering fails
    """
    try:
        from jinja2 import Template
    except ImportError:
        Template = None
        logger.error("Jinja2 not available, falling back to basic rendering")

    try:
        if Template is None:
            return _render_basic(template_content, context)
        # NOTE(review): autoescape only protects rendered values. If
        # template_content itself can come from an untrusted source, consider
        # jinja2.sandbox.SandboxedEnvironment instead of a plain Template.
        template = Template(template_content, autoescape=True)
        return template.render(**context)
    except Exception as e:
        logger.error(f"Template rendering failed: {e}")
        return ""
def implement_rate_limiting(user_id: str, action: str, limit: int, period: int) -> bool:
    """
    Basic rate-limiting hook intended to prevent abuse.

    Placeholder only: nothing is counted or stored, so ``limit`` and
    ``period`` are currently unused and every request is allowed. A real
    implementation would track request counts in a database or cache.

    Args:
        user_id: Identifier for the user
        action: Action being performed
        limit: Maximum number of actions allowed
        period: Time period in seconds

    Returns:
        True if action is allowed, False if rate limit exceeded
        (the placeholder always returns True)
    """
    # A real implementation would:
    #   1. look up how many times the user performed the action in the period
    #   2. allow and increment the counter when under the limit
    #   3. deny otherwise
    allowed = True
    message = f"Rate limiting check for user {user_id}, action {action}"
    logger.info(message)
    return allowed
# Integration example for the AIBOM generator
def secure_aibom_generation(model_id: str, output_file: Optional[str] = None) -> Dict[str, Any]:
    """
    Example of how to integrate security improvements into AIBOM generation.

    Args:
        model_id: Model ID to generate AIBOM for
        output_file: Optional output file path

    Returns:
        Generated AIBOM data, or a dict with an "error" key when validation
        fails or an unexpected error occurs
    """
    try:
        # Validate input
        validated_model_id = validate_model_id(model_id)

        # Process model ID securely
        # (This would call your actual AIBOM generation logic)
        aibom_data = {"message": f"AIBOM for {validated_model_id}"}

        # Handle output file securely if provided
        if output_file:
            # NOTE(review): the directory part is taken from output_file as
            # given, so this does not confine output to a base directory.
            safe_output_path = safe_path_join(os.path.dirname(output_file), os.path.basename(output_file))
            written = secure_file_operations(safe_output_path, 'write', json.dumps(aibom_data, indent=2))
            if not written:
                # Do not silently ignore a failed write; the data is still
                # returned so the caller can use it.
                logger.error(f"Failed to write AIBOM to {safe_output_path}")

        return aibom_data
    except ValueError as e:
        # Handle validation errors
        logger.error(f"Validation error: {e}")
        return {"error": "Invalid input parameters"}
    except Exception as e:
        # Handle unexpected errors
        logger.error(f"AIBOM generation failed: {e}")
        return {"error": "An internal error occurred"}