lainlives committed on
Commit
6cab965
·
verified ·
1 Parent(s): 3f05082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -1,10 +1,13 @@
1
- from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
- from typing import List
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
-
 
7
  app = FastAPI()
 
8
 
9
  # Load model
10
  quant_config = BitsAndBytesConfig(
@@ -20,6 +23,14 @@ model = AutoModelForCausalLM.from_pretrained(
20
  quantization_config=quant_config,
21
  device_map="auto" # Use 'auto' to let it balance between GPU/CPU
22
  )
 
 
 
 
 
 
 
 
23
  class Message(BaseModel):
24
  role: str
25
  content: str
@@ -30,7 +41,7 @@ class ChatRequest(BaseModel):
30
  temperature: float = 0.7
31
 
32
  @app.post("/v1/chat/completions")
33
- async def chat_endpoint(request: ChatRequest):
34
  # Get the last user message
35
  prompt = request.messages[-1].content
36
 
 
1
+ from fastapi import FastAPI, Depends, HTTPException, status
2
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
  from pydantic import BaseModel
4
+ from typing import List, Optional
5
  import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
7
+ HF_TOKEN = os.getenv('HF_TOKEN')
8
+ API_KEY = os.getenv('API_KEY')
9
  app = FastAPI()
10
+ security = HTTPBearer()
11
 
12
  # Load model
13
  quant_config = BitsAndBytesConfig(
 
23
  quantization_config=quant_config,
24
  device_map="auto" # Use 'auto' to let it balance between GPU/CPU
25
  )
26
+ def validate_api_key(auth: HTTPAuthorizationCredentials = Depends(security)):
27
+ if auth.credentials != API_KEY:
28
+ raise HTTPException(
29
+ status_code=status.HTTP_401_UNAUTHORIZED,
30
+ detail="Invalid or missing API Key",
31
+ headers={"WWW-Authenticate": "Bearer"},
32
+ )
33
+ return auth.credentials
34
  class Message(BaseModel):
35
  role: str
36
  content: str
 
41
  temperature: float = 0.7
42
 
43
  @app.post("/v1/chat/completions")
44
+ async def chat_endpoint(request: ChatRequest, token: str = Depends(validate_api_key)):
45
  # Get the last user message
46
  prompt = request.messages[-1].content
47