File size: 5,273 Bytes
b0b150b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

import threading
import time
from collections import defaultdict
from collections import deque
from functools import wraps
from typing import Callable, Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse

class RateLimiter:
    """
    Simple in-memory rate limiter for API endpoints.
    Uses a sliding window algorithm.

    For every key the limiter records the (monotonic) times of accepted
    requests. A new request is accepted only while fewer than
    ``max_requests`` of those fall inside the trailing ``window_seconds``.
    Rejected requests are not recorded, so they do not extend a lockout.

    State is held in process memory behind a lock. Keys whose recorded
    requests have all aged out are swept periodically, so the table does not
    grow without bound as new clients appear.
    """

    # Number of is_allowed() calls between sweeps of idle keys.
    _SWEEP_INTERVAL = 1000

    def __init__(self):
        # key -> accepted-request times from time.monotonic(), oldest first.
        self._requests = defaultdict(deque)
        # key -> window length (seconds) most recently used with that key; the
        # sweep needs it to decide whether a key's history has fully expired.
        self._windows = {}
        self._calls_since_sweep = 0
        self._lock = threading.RLock()

    def is_allowed(
        self,
        key: str,
        max_requests: int = 60,
        window_seconds: int = 60
    ) -> tuple[bool, dict]:
        """
        Check if a request is allowed under rate limits, recording it if so.

        Args:
            key: Bucket identifier (e.g. client plus endpoint type).
            max_requests: Requests permitted per window. A value <= 0
                rejects every request.
            window_seconds: Length of the sliding window in seconds.

        Returns: (is_allowed, info_dict). ``info_dict`` always contains
        ``limit``, ``remaining`` and ``reset`` (Unix epoch seconds). A
        rejection additionally contains ``retry_after`` (whole seconds).
        """
        with self._lock:
            # The window is measured on the monotonic clock so wall-clock jumps
            # (NTP steps, manual changes) cannot shrink or stretch it. The wall
            # clock is only used to express 'reset' as an epoch timestamp.
            now = time.monotonic()
            wall_now = time.time()
            window_start = now - window_seconds

            self._calls_since_sweep += 1
            if self._calls_since_sweep >= self._SWEEP_INTERVAL:
                self._sweep_idle_keys(now)
                self._calls_since_sweep = 0

            self._windows[key] = window_seconds
            timestamps = self._requests[key]
            # Entries are appended in non-decreasing order, so expired ones
            # are always at the left end.
            while timestamps and timestamps[0] <= window_start:
                timestamps.popleft()

            current_count = len(timestamps)

            if current_count >= max_requests:
                if timestamps:
                    # A slot frees up when the oldest entry leaves the window.
                    wait = timestamps[0] - window_start
                else:
                    # max_requests <= 0: nothing recorded, nothing will expire;
                    # report one full window instead of crashing on [0].
                    wait = window_seconds
                return False, {
                    'limit': max_requests,
                    'remaining': 0,
                    'reset': int(wall_now + wait),
                    'retry_after': int(wait) + 1
                }

            # Add current request
            timestamps.append(now)

            return True, {
                'limit': max_requests,
                'remaining': max_requests - current_count - 1,
                'reset': int(wall_now + window_seconds)
            }

    def _sweep_idle_keys(self, now: float) -> None:
        """Forget keys with no request inside their window. Caller holds the lock."""
        idle = [
            k for k, stamps in self._requests.items()
            if not stamps or stamps[-1] <= now - self._windows.get(k, 0)
        ]
        for k in idle:
            del self._requests[k]
            self._windows.pop(k, None)

    def reset(self, key: str):
        """Reset rate limit for a key."""
        with self._lock:
            self._requests.pop(key, None)
            self._windows.pop(key, None)


# Singleton instance shared by everything that imports this module. Its
# counters live in this process's memory, so limits are enforced per process,
# not across multiple processes or hosts.
rate_limiter = RateLimiter()


# Rate limit configurations per endpoint type: at most 'max_requests' requests
# per client within a sliding window of 'window' seconds. 'default' covers any
# request path the middleware does not map to a more specific type.
RATE_LIMITS = dict(
    auth=dict(max_requests=10, window=60),        # 10 per minute
    chat=dict(max_requests=30, window=60),        # 30 per minute
    compile=dict(max_requests=5, window=300),     # 5 per 5 minutes
    agents=dict(max_requests=60, window=60),      # 60 per minute
    default=dict(max_requests=100, window=60),    # 100 per minute
)


# Ordered (path fragment, limit type) rules. The first fragment found anywhere
# in the request path wins; unmatched paths use 'default'. Note the fragments
# are plain substrings and only some carry a trailing slash.
_ENDPOINT_RULES = (
    ('/auth/', 'auth'),
    ('/chat/', 'chat'),
    ('/compile', 'compile'),
    ('/agents', 'agents'),
)


def _endpoint_type(path: str) -> str:
    """Return the RATE_LIMITS key for *path* using ordered substring matching."""
    for fragment, limit_type in _ENDPOINT_RULES:
        if fragment in path:
            return limit_type
    return 'default'


def _rate_limit_headers(info: dict) -> dict:
    """Build the X-RateLimit-* response headers from a limiter info dict."""
    return {
        'X-RateLimit-Limit': str(info['limit']),
        'X-RateLimit-Remaining': str(info['remaining']),
        'X-RateLimit-Reset': str(info['reset']),
    }


async def rate_limit_middleware(request: Request, call_next):
    """
    FastAPI middleware for rate limiting.

    Each request is counted against a bucket keyed by client IP and endpoint
    type. Over-limit requests get a 429 JSON response with ``Retry-After``.
    Allowed requests are passed on, and the X-RateLimit-* headers are added
    to the downstream response.
    """
    # Clients are identified by IP address only. Requests without client
    # information all share the single "unknown" bucket.
    client_ip = request.client.host if request.client else "unknown"

    limit_type = _endpoint_type(request.url.path)
    limits = RATE_LIMITS[limit_type]
    key = f"{client_ip}:{limit_type}"

    allowed, info = rate_limiter.is_allowed(
        key,
        max_requests=limits['max_requests'],
        window_seconds=limits['window'],
    )
    headers = _rate_limit_headers(info)

    if not allowed:
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                'detail': 'Too many requests',
                'retry_after': info['retry_after'],
            },
            headers={**headers, 'Retry-After': str(info['retry_after'])},
        )

    response = await call_next(request)

    # These overwrite any same-named headers the route may have set.
    for name, value in headers.items():
        response.headers[name] = value

    return response


# File validation constants
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
ALLOWED_EXTENSIONS = {'.csv', '.pdf', '.docx', '.txt', '.json', '.xlsx'}


def validate_file_upload(filename: Optional[str], file_size: int) -> Optional[str]:
    """
    Validate an uploaded file.
    Returns error message if invalid, None if valid.

    Args:
        filename: Client-supplied file name. ``None`` or a name without an
            extension is treated as having no extension, so it is rejected.
        file_size: Size in bytes, compared against MAX_FILE_SIZE (inclusive).
    """
    import os

    # Only the final suffix counts, compared case-insensitively
    # ("a.tar.CSV" -> ".csv").
    ext = os.path.splitext(filename or "")[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        # Sorted so the message is identical across processes; str hashing is
        # randomized per process, so raw set iteration order is not stable.
        allowed = ', '.join(sorted(ALLOWED_EXTENSIONS))
        return f"File type '{ext}' not allowed. Allowed types: {allowed}"

    if file_size > MAX_FILE_SIZE:
        max_mb = MAX_FILE_SIZE / (1024 * 1024)
        # ':g' renders 50.0 as "50" while keeping fractional limits like 1.5.
        return f"File too large. Maximum size is {max_mb:g}MB"

    return None


# Security headers middleware
async def security_headers_middleware(request: Request, call_next):
    """
    Add security headers to all responses.

    Values are set unconditionally and overwrite any same-named header a
    route already placed on the response.
    """
    response = await call_next(request)

    # Stop browsers from MIME-sniffing away from the declared Content-Type.
    response.headers['X-Content-Type-Options'] = 'nosniff'
    # Forbid rendering this response inside any frame (clickjacking defense).
    response.headers['X-Frame-Options'] = 'DENY'
    # '0' disables the legacy browser XSS auditor. OWASP and MDN advise against
    # '1; mode=block' because the auditor itself could be abused to create
    # vulnerabilities, and modern browsers have removed it. A
    # Content-Security-Policy is the maintained replacement.
    response.headers['X-XSS-Protection'] = '0'
    # Full URL for same-origin requests; origin only for cross-origin, and
    # nothing on an HTTPS -> HTTP downgrade.
    response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'

    return response