File size: 2,724 Bytes
0e805d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Security utilities for input sanitization and rate limiting."""

import math
import os
import stat
import time
from collections import defaultdict
from typing import Dict, List

from core.config import (
    MAX_PROMPT_LENGTH,
    MAX_REQUESTS_PER_HOUR,
    REQUEST_WINDOW_SECONDS,
)


class SecurityManager:
    """Manages security features like rate limiting and input sanitization.

    Rate-limit state lives in memory on the instance (``user_requests``), so
    limits are tracked per ``SecurityManager`` object and are lost when the
    object is discarded.
    """

    # Characters that cause ``sanitize_prompt`` to reject a prompt.
    # NOTE: '\n', '\r' and '\0' are non-printable, so ``sanitize_prompt``
    # removes them before this list is consulted; they can never trigger a
    # rejection. They are kept so this public attribute keeps its contents.
    FORBIDDEN_CHARS = ['<', '>', '|', '&', ';', '`', '$', '(', ')', '\n', '\r', '\0', '\\']

    def __init__(self):
        # user_id -> time.time() timestamps of requests accepted in the
        # current window; pruned lazily by check_rate_limit.
        self.user_requests: Dict[str, List[float]] = defaultdict(list)

    def sanitize_prompt(self, prompt: str) -> str:
        """Return a cleaned copy of *prompt*, or raise if it is unacceptable.

        Characters for which ``str.isprintable()`` is False are removed and
        surrounding whitespace is stripped; the cleaned text is then
        validated.

        Args:
            prompt: Raw user input.

        Returns:
            The cleaned, non-empty prompt.

        Raises:
            ValueError: If the prompt is empty or blank (before or after
                cleaning), longer than ``MAX_PROMPT_LENGTH`` once cleaned,
                or contains any character in ``FORBIDDEN_CHARS``.
        """
        # Guard first so None / "" are rejected without touching the input.
        if not prompt:
            raise ValueError("Prompt cannot be empty")

        # Drop control characters. NOTE(review): newlines/tabs are removed,
        # not replaced, so "a\nb" becomes "ab" -- confirm this is intended.
        cleaned = ''.join(ch for ch in prompt if ch.isprintable()).strip()

        # Re-check after cleaning: input made only of control characters
        # would otherwise slip through as an empty prompt.
        if not cleaned:
            raise ValueError("Prompt cannot be empty")

        # Measure what is actually returned, not the padded raw input.
        if len(cleaned) > MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt too long (max {MAX_PROMPT_LENGTH} characters, got {len(cleaned)})"
            )

        for char in self.FORBIDDEN_CHARS:
            if char in cleaned:
                raise ValueError(f"Forbidden character in prompt: {char}")

        return cleaned

    def check_rate_limit(self, user_id: str = "default") -> None:
        """Record a request for *user_id*, raising if the limit is exceeded.

        Requests older than ``REQUEST_WINDOW_SECONDS`` are discarded first;
        if ``MAX_REQUESTS_PER_HOUR`` or more remain, the request is refused
        and not recorded.

        Args:
            user_id: Key identifying the caller whose requests are counted.

        Raises:
            ValueError: If the user has reached the limit for the window.
        """
        # NOTE(review): time.time() is wall-clock; a clock adjustment can
        # shorten or extend the window. time.monotonic() would be immune.
        now = time.time()

        # Keep only requests still inside the sliding window.
        window = [
            t for t in self.user_requests[user_id]
            if now - t < REQUEST_WINDOW_SECONDS
        ]
        self.user_requests[user_id] = window

        if len(window) >= MAX_REQUESTS_PER_HOUR:
            # min() rather than window[0]: wall-clock stamps need not be
            # ordered, and the window is empty when the limit is <= 0.
            oldest = min(window, default=now)
            remaining = REQUEST_WINDOW_SECONDS - (now - oldest)
            # Round up so the last partial minute is not reported as "0".
            minutes = max(1, math.ceil(remaining / 60))
            raise ValueError(
                f"Rate limit exceeded. Try again in {minutes} minutes."
            )

        window.append(now)
        print(f"[Security] Rate limit: {len(window)}/{MAX_REQUESTS_PER_HOUR}")

    def validate_file_size(self, file_path: str, max_size_mb: float = 100.0) -> float:
        """Check that *file_path* is a regular file no larger than a limit.

        Sizes are in decimal megabytes (1 MB = 1e6 bytes).

        Args:
            file_path: Path to the file to check (symlinks are followed).
            max_size_mb: Largest allowed size in megabytes.

        Returns:
            The file size in megabytes.

        Raises:
            ValueError: If the path does not exist or cannot be stat'ed, is
                not a regular file, or exceeds *max_size_mb*.
        """
        # A single stat call avoids the exists-then-getsize race. ValueError
        # covers paths os.stat refuses outright (e.g. an embedded NUL byte).
        try:
            info = os.stat(file_path)
        except (OSError, ValueError) as err:
            raise ValueError(f"File not found: {file_path}") from err

        # Directories and devices report a size that says nothing about the
        # payload, so they are not accepted.
        if not stat.S_ISREG(info.st_mode):
            raise ValueError(f"Not a regular file: {file_path}")

        size_mb = info.st_size / 1e6
        if size_mb > max_size_mb:
            raise ValueError(
                f"File too large: {size_mb:.2f}MB (max {max_size_mb:.2f}MB)"
            )

        return size_mb