File size: 4,919 Bytes
80e0598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import time
from collections import defaultdict
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import logging
import asyncio  # Concurrency limiting

logger = logging.getLogger(__name__)

class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window, per-client-IP rate limiter for selected route prefixes.

    Each protected request records a timestamp under the client's IP. A
    request is rejected with HTTP 429 when the client already has
    ``rate_limit_per_minute`` accepted requests inside the last
    ``rate_limit_window`` seconds. Rejected requests are not recorded.

    State is kept in process memory (``self.ip_requests``).
    """

    def __init__(
        self,
        app,
        rate_limit_per_minute=10,
        rate_limit_window=60,
        protected_routes=None,
    ):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            rate_limit_per_minute: Maximum accepted requests per client
                within one window (the window is ``rate_limit_window``
                seconds, not necessarily a minute).
            rate_limit_window: Length of the sliding window in seconds.
            protected_routes: Path prefixes to rate limit. ``None`` selects
                the built-in generate endpoints; an empty list protects
                nothing.
        """
        super().__init__(app)
        self.rate_limit_per_minute = rate_limit_per_minute
        self.rate_limit_window = rate_limit_window
        # A None sentinel avoids sharing one mutable default list between
        # every instance; copying also decouples us from the caller's list.
        if protected_routes is None:
            protected_routes = ["/generate", "/api/generate", "/api/generate-with-report"]
        self.protected_routes = list(protected_routes)
        self.ip_requests = defaultdict(list)
        # Time of the last full sweep that evicted idle clients.
        self._last_sweep = time.time()
        logger.info(f"Rate limit middleware initialized: {rate_limit_per_minute} requests per {rate_limit_window}s")

    def _purge_stale(self, now):
        """Evict clients with no request inside the current window.

        Runs at most once per window so the cost stays amortised; without
        it, every IP ever seen would keep an entry forever.
        """
        if now - self._last_sweep < self.rate_limit_window:
            return
        self._last_sweep = now
        stale = [
            ip
            for ip, stamps in self.ip_requests.items()
            # Timestamps are appended in arrival order, so the last one is
            # the most recent.
            if not stamps or now - stamps[-1] >= self.rate_limit_window
        ]
        for ip in stale:
            del self.ip_requests[ip]

    async def dispatch(self, request: Request, call_next):
        """Reject over-limit requests to protected routes, else pass through.

        Returns:
            A 429 ``JSONResponse`` when the client is over its limit,
            otherwise whatever the downstream application returns.
        """
        path = request.url.path
        # str.startswith accepts a tuple of prefixes; an empty tuple never
        # matches, so an empty route list protects nothing.
        if not path.startswith(tuple(self.protected_routes)):
            return await call_next(request)

        # request.client can be None (no peer address supplied by the
        # server); bucket such requests together instead of crashing.
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()
        self._purge_stale(now)

        # Keep only the timestamps still inside the window. .get() avoids
        # creating an entry as a side effect of the lookup.
        recent = [
            stamp
            for stamp in self.ip_requests.get(client_ip, ())
            if now - stamp < self.rate_limit_window
        ]

        if len(recent) >= self.rate_limit_per_minute:
            self.ip_requests[client_ip] = recent
            logger.warning("Rate limit exceeded for IP %s on %s", client_ip, path)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        recent.append(now)
        self.ip_requests[client_ip] = recent
        return await call_next(request)

class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
    """Cap the number of protected requests processed at the same time.

    A semaphore with ``max_concurrent_requests`` slots guards the protected
    route prefixes. A request that cannot obtain a slot within ``timeout``
    seconds is answered with HTTP 503. Other routes pass straight through.
    """

    def __init__(
        self,
        app,
        max_concurrent_requests=5,
        timeout=5.0,
        protected_routes=None,
    ):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            max_concurrent_requests: Number of semaphore slots.
            timeout: Seconds to wait for a free slot before returning 503.
            protected_routes: Path prefixes to guard. ``None`` or an empty
                value selects the built-in generate endpoints.
        """
        super().__init__(app)
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)
        self.timeout = timeout
        self.protected_routes = protected_routes or ["/generate", "/api/generate", "/api/generate-with-report"]
        logger.info(f"Concurrency limit middleware initialized: {max_concurrent_requests} concurrent requests")

    async def dispatch(self, request, call_next):
        """Run the request under the semaphore when its route is protected.

        Returns:
            * 503 ``JSONResponse`` if no slot became free within ``timeout``.
            * 500 ``JSONResponse`` if anything below this middleware raised.
            * Otherwise the downstream response.
        """
        try:
            if not request.url.path.startswith(tuple(self.protected_routes)):
                # For non-protected routes, proceed normally
                return await call_next(request)

            # Only the acquire is inside this try: a TimeoutError raised by
            # the downstream handler must not be reported as "at capacity".
            try:
                await asyncio.wait_for(self.semaphore.acquire(), timeout=self.timeout)
            except asyncio.TimeoutError:
                logger.warning(f"Concurrency limit reached for {request.url.path}")
                return JSONResponse(
                    status_code=503,
                    content={"detail": "Server is at capacity. Please try again later."},
                )

            # The slot is held from here on; always hand it back.
            try:
                return await call_next(request)
            finally:
                self.semaphore.release()
        except Exception as exc:
            # Log the full traceback server-side, but do not echo exception
            # text to the client: it can expose internal details.
            logger.exception("Error in ConcurrencyLimitMiddleware: %s", exc)
            return JSONResponse(
                status_code=500,
                content={"detail": "Internal server error"},
            )


# Protection against large request payloads
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose declared Content-Length exceeds a limit.

    Only the ``Content-Length`` header is inspected; the body itself is not
    read or measured here, so a request sent without that header is passed
    through unchecked.
    """

    def __init__(self, app, max_content_length=1024*1024):  # 1MB default
        """Store the limit.

        Args:
            app: The ASGI application being wrapped.
            max_content_length: Largest accepted Content-Length, in bytes.
        """
        super().__init__(app)
        self.max_content_length = max_content_length
        logger.info(f"Request size limit middleware initialized: {max_content_length} bytes")

    async def dispatch(self, request: Request, call_next):
        """Validate the Content-Length header, then forward the request.

        Returns:
            * 400 ``JSONResponse`` if the header is not a non-negative integer.
            * 413 ``JSONResponse`` if it exceeds ``max_content_length``.
            * Otherwise the downstream response.
        """
        content_length = request.headers.get('content-length')
        if content_length:
            try:
                declared = int(content_length)
            except ValueError:
                # Fold unparsable values into the negative-length rejection
                # below instead of letting int() raise into a 500.
                declared = -1

            if declared < 0:
                logger.warning("Invalid Content-Length header: %r", content_length)
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )

            if declared > self.max_content_length:
                logger.warning(f"Request too large: {content_length} bytes")
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request too large"},
                )
        return await call_next(request)