cotcotquedec commited on
Commit
17e63a8
·
1 Parent(s): 1ce4699

feat(auth): implement token-based authentication middleware

Browse files

Added a new authentication middleware to handle token-based authorization for all endpoints. This middleware checks for the presence of an authorization header, validates the bearer token, and stores it in a context variable for subsequent requests. This change enhances security by ensuring that only requests with valid tokens can access the API endpoints.

Additionally, modified the existing functions to utilize the token context for creating Anthropic clients, ensuring that the correct token is used for API interactions.

ref #123: enhance security with token-based auth middleware

Files changed (1) hide show
  1. main.py +55 -9
main.py CHANGED
@@ -1,16 +1,57 @@
1
- import os
2
- from fastapi import FastAPI, HTTPException
3
  from fastapi.responses import JSONResponse, StreamingResponse
 
 
4
  from pydantic import BaseModel
5
  from typing import List, Optional
6
  from anthropic import Anthropic
7
  import json
8
  import time
 
9
 
10
  app = FastAPI()
 
11
 
12
- # Initialize Anthropic client with environment variable
13
- client = Anthropic(api_key=os.getenv('ANTHROPIC_API_KEY'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Available models
16
  AVAILABLE_MODELS = [
@@ -31,12 +72,15 @@ class ChatCompletionRequest(BaseModel):
31
  max_tokens: Optional[int] = 1024
32
 
33
  @app.get("/")
34
- def read_root():
35
  return {"Hello": "World!"}
36
 
37
  @app.get("/models")
38
  async def get_models():
39
  """Get available Anthropic models."""
 
 
 
40
  models = [
41
  {
42
  "id": model_id,
@@ -57,8 +101,6 @@ async def get_models():
57
  }
58
  )
59
 
60
- return {"data": models, "object": "list"}
61
-
62
  @app.post("/v1/chat/completions")
63
  async def create_chat_completion(request: ChatCompletionRequest):
64
  """Generate chat completions using Anthropic models."""
@@ -73,12 +115,13 @@ async def create_chat_completion(request: ChatCompletionRequest):
73
  except Exception as e:
74
  raise HTTPException(status_code=500, detail=str(e))
75
 
76
-
77
-
78
  async def generate_completion(request: ChatCompletionRequest):
79
  """Generate a non-streaming completion."""
80
  messages = [{"role": m.role, "content": m.content} for m in request.messages]
81
 
 
 
 
82
  response = client.messages.create(
83
  model=request.model,
84
  max_tokens=request.max_tokens,
@@ -109,6 +152,9 @@ async def stream_response(request: ChatCompletionRequest):
109
  """Stream the completion response."""
110
  messages = [{"role": m.role, "content": m.content} for m in request.messages]
111
 
 
 
 
112
  response = client.messages.create(
113
  model=request.model,
114
  max_tokens=request.max_tokens,
 
1
+ from fastapi import FastAPI, HTTPException, Request
 
2
  from fastapi.responses import JSONResponse, StreamingResponse
3
+ from fastapi.security import HTTPBearer
4
+ from fastapi.middleware.base import BaseHTTPMiddleware
5
  from pydantic import BaseModel
6
  from typing import List, Optional
7
  from anthropic import Anthropic
8
  import json
9
  import time
10
+ from contextvars import ContextVar
11
 
12
  app = FastAPI()
13
+ security = HTTPBearer()
14
 
15
+ # Context variable to store the token
16
+ token_context = ContextVar('token', default=None)
17
+
18
+ # Middleware pour récupérer et vérifier le token sur tous les endpoints
19
+ class AuthMiddleware(BaseHTTPMiddleware):
20
+ async def dispatch(self, request: Request, call_next):
21
+ try:
22
+ auth_header = request.headers.get('Authorization')
23
+ if not auth_header:
24
+ raise HTTPException(status_code=401, detail="No authorization header")
25
+
26
+ scheme, token = auth_header.split()
27
+ if scheme.lower() != 'bearer':
28
+ raise HTTPException(status_code=401, detail="Invalid authentication scheme")
29
+
30
+ # Store token in context
31
+ token_context.set(token)
32
+
33
+ except HTTPException as e:
34
+ return JSONResponse(
35
+ status_code=e.status_code,
36
+ content={"detail": e.detail}
37
+ )
38
+ except Exception as e:
39
+ return JSONResponse(
40
+ status_code=401,
41
+ content={"detail": "Invalid authorization header"}
42
+ )
43
+
44
+ return await call_next(request)
45
+
46
+ # Ajouter le middleware à l'application
47
+ app.add_middleware(AuthMiddleware)
48
+
49
+ # Function to get Anthropic client with current token
50
+ def get_anthropic_client():
51
+ token = token_context.get()
52
+ if not token:
53
+ raise HTTPException(status_code=401, detail="No authorization token found")
54
+ return Anthropic(api_key=token)
55
 
56
  # Available models
57
  AVAILABLE_MODELS = [
 
72
  max_tokens: Optional[int] = 1024
73
 
74
  @app.get("/")
75
+ async def read_root():
76
  return {"Hello": "World!"}
77
 
78
  @app.get("/models")
79
  async def get_models():
80
  """Get available Anthropic models."""
81
+ # Test the token by creating a client
82
+ get_anthropic_client()
83
+
84
  models = [
85
  {
86
  "id": model_id,
 
101
  }
102
  )
103
 
 
 
104
  @app.post("/v1/chat/completions")
105
  async def create_chat_completion(request: ChatCompletionRequest):
106
  """Generate chat completions using Anthropic models."""
 
115
  except Exception as e:
116
  raise HTTPException(status_code=500, detail=str(e))
117
 
 
 
118
  async def generate_completion(request: ChatCompletionRequest):
119
  """Generate a non-streaming completion."""
120
  messages = [{"role": m.role, "content": m.content} for m in request.messages]
121
 
122
+ # Get client with current token
123
+ client = get_anthropic_client()
124
+
125
  response = client.messages.create(
126
  model=request.model,
127
  max_tokens=request.max_tokens,
 
152
  """Stream the completion response."""
153
  messages = [{"role": m.role, "content": m.content} for m in request.messages]
154
 
155
+ # Get client with current token
156
+ client = get_anthropic_client()
157
+
158
  response = client.messages.create(
159
  model=request.model,
160
  max_tokens=request.max_tokens,