geqintan commited on
Commit
04d75a1
·
1 Parent(s): 408f2cf
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -1,12 +1,23 @@
1
- from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel, Field
3
  from sentence_transformers import SentenceTransformer
4
- import logging
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
 
 
 
 
 
 
 
 
 
 
 
10
  app = FastAPI()
11
 
12
  try:
@@ -20,9 +31,10 @@ except Exception as e:
20
  class RerankerRequest(BaseModel):
21
  query: str = Field(..., min_length=1, max_length=1000, description="The query text.")
22
  documents: list[str] = Field(..., min_items=2, description="A list of documents to rerank.")
 
23
 
24
  @app.post("/rerank")
25
- async def rerank(request: RerankerRequest):
26
  query = request.query
27
  documents = request.documents
28
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends, Header
2
  from pydantic import BaseModel, Field
3
  from sentence_transformers import SentenceTransformer
4
+ import logging, os
5
 
6
  # Configure logging
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
+ # 定义依赖项来校验 Authorization
11
+ async def check_authorization(authorization: str = Header(..., alias="Authorization")):
12
+ # 去掉 Bearer 和后面的空格
13
+ if not authorization.startswith("Bearer "):
14
+ raise HTTPException(status_code=401, detail="Invalid Authorization header format")
15
+
16
+ token = authorization[len("Bearer "):]
17
+ if token != os.environ.get("AUTHORIZATION"):
18
+ raise HTTPException(status_code=401, detail="Unauthorized access")
19
+ return token
20
+
21
  app = FastAPI()
22
 
23
  try:
 
31
  class RerankerRequest(BaseModel):
32
  query: str = Field(..., min_length=1, max_length=1000, description="The query text.")
33
  documents: list[str] = Field(..., min_items=2, description="A list of documents to rerank.")
34
+ truncate: bool = Field(False, description="Whether to truncate the documents.")
35
 
36
  @app.post("/rerank")
37
+ async def rerank(request: RerankerRequest, authorization: str = Depends(check_authorization)):
38
  query = request.query
39
  documents = request.documents
40