File size: 5,609 Bytes
2a8faae
0176a31
2a8faae
 
 
ddc9c77
 
2a8faae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddc9c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a8faae
 
0f194af
 
 
 
 
 
0a5dcf9
0f194af
0176a31
 
0f194af
0a5dcf9
 
 
 
 
 
0f194af
 
2a8faae
0f194af
2a8faae
0176a31
 
 
2a8faae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Middleware for Lung Cancer AI Advisor API
"""
import time
import logging
from typing import Callable, Awaitable, Optional
from fastapi import Request, Response, HTTPException, Cookie
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)


class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware that reports request processing time in a response header.

    The elapsed time, in seconds formatted to four decimal places, is written
    to the ``X-Process-Time`` header of every response that passes through.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Time the downstream handling of ``request`` and annotate the response.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next layer
                and returns its response.

        Returns:
            The downstream response with ``X-Process-Time`` added.
        """
        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump if the system wall clock is adjusted mid-request.
        started = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - started
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response


class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs each request and the outcome of handling it.

    One INFO record is written when the request arrives and one when the
    response is produced. If the downstream layers raise, an ERROR record is
    written and the exception is re-raised unchanged.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Log the request, forward it, and log the response or the failure.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next layer
                and returns its response.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Whatever the downstream layers raised, after logging it.
        """
        # Monotonic clock: durations stay correct across wall-clock changes.
        started = time.perf_counter()

        # Lazy %-style arguments defer string formatting until a handler
        # actually emits the record; the rendered text is unchanged.
        logger.info("Request: %s %s", request.method, request.url)

        try:
            response = await call_next(request)
        except Exception as exc:
            logger.error(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                time.perf_counter() - started,
                request.url.path,
            )
            raise

        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            time.perf_counter() - started,
            request.url.path,
        )
        return response


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-client rate limiting middleware.

    Each client (keyed by remote host) may make at most ``calls_per_minute``
    requests within any trailing 60-second window. State lives in this
    instance's ``client_calls`` dict, so it is per-process and is lost on
    restart.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, app, calls_per_minute: int = 60):
        """Create the limiter.

        Args:
            app: The ASGI application being wrapped.
            calls_per_minute: Maximum number of requests a single client may
                make inside the sliding window.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # Maps client host -> list of timestamps (time.time()) of its
        # accepted requests that are still inside the window.
        self.client_calls = {}

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Accept or reject ``request`` based on the client's recent call count.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next layer
                and returns its response.

        Returns:
            The downstream response, or a 429 JSON response with a ``detail``
            message when the client has exhausted its quota.
        """
        # Local import, matching the function-scope imports used elsewhere in
        # this module.
        from starlette.responses import JSONResponse

        # request.client is None when the server supplies no peer address
        # (e.g. some test clients / unix sockets); bucket those together
        # instead of crashing with AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()
        window_start = now - self.WINDOW_SECONDS

        # Drop expired timestamps for every client, and drop clients with no
        # remaining activity, so memory does not grow without bound.
        active = {}
        for ip, calls in self.client_calls.items():
            recent = [stamp for stamp in calls if stamp > window_start]
            if recent:
                active[ip] = recent
        self.client_calls = active

        recent_calls = active.get(client_ip, [])
        if len(recent_calls) >= self.calls_per_minute:
            # Return the response directly: an HTTPException raised inside a
            # BaseHTTPMiddleware is not seen by FastAPI's exception handlers
            # and would reach the client as a 500. The body mirrors the
            # {"detail": ...} shape FastAPI uses for HTTPException.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        # Rejected requests are intentionally not recorded; only accepted
        # ones count against the quota.
        recent_calls.append(now)
        active[client_ip] = recent_calls

        return await call_next(request)


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware to protect endpoints with session authentication.

    Requests whose path is listed in ``PUBLIC_PATHS`` (or lies beneath one of
    those paths) pass through untouched. Every other request must carry a
    ``session_token`` cookie that ``verify_session`` accepts; otherwise a 401
    JSON response is returned.
    """

    # Paths that don't require authentication
    PUBLIC_PATHS = [
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/auth/login",
        "/auth/status",
    ]

    @classmethod
    def _is_public_path(cls, path: str) -> bool:
        """Return True if ``path`` is a public path or nested under one.

        A plain ``startswith`` test cannot be used here: every URL path
        starts with "/", so the root entry would make the whole API public.
        The root is therefore matched exactly, and other entries match either
        exactly or on a path-segment boundary ("/docs" matches "/docs/x" but
        not "/docsecret").
        """
        for public_path in cls.PUBLIC_PATHS:
            if path == public_path:
                return True
            if public_path != "/" and path.startswith(public_path.rstrip("/") + "/"):
                return True
        return False

    @staticmethod
    def _unauthorized(detail: str) -> Response:
        """Build a 401 JSON response in the ``{"detail": ...}`` shape."""
        # Local import, matching the function-scope imports used elsewhere in
        # this module.
        from starlette.responses import JSONResponse

        return JSONResponse(status_code=401, content={"detail": detail})

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Let public or authenticated requests through; reject the rest.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next layer
                and returns its response.

        Returns:
            The downstream response, or a 401 JSON response when the session
            cookie is missing, invalid, or expired.
        """
        # Allow public paths
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # Check for session token
        session_token = request.cookies.get("session_token")
        if not session_token:
            # Responses are returned rather than raised: an HTTPException
            # raised inside a BaseHTTPMiddleware is not seen by FastAPI's
            # exception handlers and would reach the client as a 500.
            return self._unauthorized("Authentication required")

        # Verify session. Imported lazily, as in the original code —
        # presumably to avoid a circular import with the auth router; confirm
        # before hoisting it to module level.
        from api.routers.auth import verify_session
        session_data = verify_session(session_token)
        if not session_data:
            return self._unauthorized("Invalid or expired session")

        # Add user info to request state
        request.state.user = session_data.get("username")

        return await call_next(request)


def get_cors_middleware_config() -> dict:
    """Get CORS middleware configuration.

    Allowed origins are read from the ``ALLOWED_ORIGINS`` environment
    variable as a comma-separated list. Whitespace around each entry is
    ignored and empty entries are discarded. If the variable is unset or
    yields no entries, a built-in list of development origins is used.

    Returns:
        A dict with the keys ``allow_origins``, ``allow_credentials``,
        ``allow_methods``, ``allow_headers`` and ``expose_headers``.
    """
    import os

    # Origins are compared verbatim with the request's Origin header, so a
    # stray space (" https://b.example") or an empty entry from a trailing
    # comma would silently never match; normalise them away here.
    raw_origins = os.getenv("ALLOWED_ORIGINS", "")
    allowed_origins = [
        origin.strip() for origin in raw_origins.split(",") if origin.strip()
    ]

    if not allowed_origins:
        # Default to allowing Hugging Face Space and localhost
        # Include null for file:// protocol and common local development origins
        # NOTE(review): "null" is also sent by sandboxed iframes and some
        # redirects; combined with allow_credentials=True it is risky outside
        # local development — confirm it is wanted in production.
        allowed_origins = [
            "http://127.0.0.1:7860",
            "https://huggingface.co",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://localhost:5500",  # Live Server default port
            "http://127.0.0.1:5500",
            "http://localhost:3000",  # Common dev server port
            "http://127.0.0.1:3000",
            "null",  # For file:// protocol
        ]

    return {
        "allow_origins": allowed_origins,
        "allow_credentials": True,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": [
            "Content-Type",
            "Authorization",
            "Accept",
            "Origin",
            "X-Requested-With",
            "Cookie",
        ],
        "expose_headers": ["Set-Cookie"],
    }