Spaces:
Runtime error
Runtime error
| """Security utilities for input sanitization and rate limiting.""" | |
| import time | |
| from collections import defaultdict | |
| from typing import Dict, List | |
| from core.config import ( | |
| MAX_PROMPT_LENGTH, | |
| MAX_REQUESTS_PER_HOUR, | |
| REQUEST_WINDOW_SECONDS, | |
| ) | |
class SecurityManager:
    """Input sanitization and per-user rate limiting.

    Limits come from ``core.config``: ``MAX_PROMPT_LENGTH`` bounds prompt
    length, and each user may make at most ``MAX_REQUESTS_PER_HOUR`` requests
    within a sliding window of ``REQUEST_WINDOW_SECONDS`` seconds.
    """

    # Characters rejected anywhere in a prompt. '\n', '\r' and '\0' can never
    # reach this check because sanitize_prompt() drops non-printable
    # characters first; they are kept as defense in depth in case that
    # filtering step ever changes.
    FORBIDDEN_CHARS = [
        '<', '>', '|', '&', ';', '`', '$', '(', ')', '\n', '\r', '\0', '\\',
    ]

    def __init__(self):
        # user_id -> time.time() timestamps of that user's requests; entries
        # older than the window are pruned lazily by check_rate_limit().
        self.user_requests: Dict[str, List[float]] = defaultdict(list)

    def sanitize_prompt(self, prompt: str) -> str:
        """Clean and validate a user-supplied prompt.

        Non-printable characters (per ``str.isprintable``, which includes
        tabs and newlines) are silently removed, then surrounding whitespace
        is stripped. The cleaned text must be non-empty, at most
        ``MAX_PROMPT_LENGTH`` characters, and free of ``FORBIDDEN_CHARS``.

        Args:
            prompt: Raw prompt text.

        Returns:
            The cleaned prompt.

        Raises:
            ValueError: If the cleaned prompt is empty, too long, or contains
                a forbidden character.
        """
        if not prompt:
            raise ValueError("Prompt cannot be empty")

        # Drop control characters *before* the emptiness test so a prompt
        # made only of e.g. "\x00\x01" is rejected instead of returned as "".
        cleaned = ''.join(c for c in prompt if c.isprintable()).strip()
        if not cleaned:
            raise ValueError("Prompt cannot be empty")

        # Measure the text that will actually be returned.
        if len(cleaned) > MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt too long (max {MAX_PROMPT_LENGTH} characters, got {len(cleaned)})"
            )

        # Iterate in FORBIDDEN_CHARS order so the reported character is
        # deterministic for a given prompt.
        for char in self.FORBIDDEN_CHARS:
            if char in cleaned:
                raise ValueError(f"Forbidden character in prompt: {char}")

        return cleaned

    def check_rate_limit(self, user_id: str = "default") -> None:
        """Record a request for ``user_id``, or reject it if over the limit.

        Timestamps older than ``REQUEST_WINDOW_SECONDS`` are discarded first.
        A rejected request is not recorded.

        Args:
            user_id: Key identifying the caller; defaults to a shared
                "default" bucket.

        Raises:
            ValueError: If the user already has ``MAX_REQUESTS_PER_HOUR``
                requests inside the current window.
        """
        import math

        now = time.time()

        # Keep only requests still inside the window.
        recent = [
            t for t in self.user_requests[user_id]
            if now - t < REQUEST_WINDOW_SECONDS
        ]
        self.user_requests[user_id] = recent

        if len(recent) >= MAX_REQUESTS_PER_HOUR:
            # The oldest in-window request is the next to expire. min() does
            # not rely on the list being sorted (wall-clock time can jump
            # backwards), and default= covers an empty list when the limit
            # is configured as zero or negative.
            oldest = min(recent, default=now)
            remaining = REQUEST_WINDOW_SECONDS - (now - oldest)
            # Round up so the last partial minute is not reported as "0".
            minutes = max(1, math.ceil(remaining / 60))
            raise ValueError(
                f"Rate limit exceeded. Try again in {minutes} minutes."
            )

        # `recent` is the list stored in user_requests, so this records it.
        recent.append(now)
        print(f"[Security] Rate limit: {len(recent)}/{MAX_REQUESTS_PER_HOUR}")

    def validate_file_size(self, file_path: str, max_size_mb: float = 100.0) -> float:
        """Check that ``file_path`` is a regular file no larger than ``max_size_mb``.

        Sizes use decimal megabytes (1 MB = 1,000,000 bytes).

        Args:
            file_path: Path of the file to check.
            max_size_mb: Inclusive upper bound on the file size, in MB.

        Returns:
            The file size in MB.

        Raises:
            ValueError: If the path is missing, is not a regular file, cannot
                be stat'ed, or exceeds ``max_size_mb``.
        """
        import os

        # isfile() rather than exists(): for a directory, getsize() returns
        # the size of the directory entry, which is meaningless here.
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")
        try:
            size_bytes = os.path.getsize(file_path)
        except OSError as err:
            # The file vanished or became unreadable after the check above.
            raise ValueError(f"File not found: {file_path}") from err

        size_mb = size_bytes / 1e6
        if size_mb > max_size_mb:
            raise ValueError(
                f"File too large: {size_mb:.2f}MB (max {max_size_mb:.2f}MB)"
            )
        return size_mb