Hivra commited on
Commit
41a8f4f
·
verified ·
1 Parent(s): 072a239

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +14 -4
app/main.py CHANGED
@@ -1,5 +1,6 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
  from fastapi.responses import StreamingResponse, JSONResponse
 
3
  from pydantic import BaseModel
4
  from gradio_client import Client
5
  import time
@@ -11,6 +12,16 @@ DEFAULT_API = "/chat"
11
 
12
  client = Client(SPACE_ID)
13
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def chat_with_gradio(message: str, api_name: str = DEFAULT_API):
16
  """
@@ -21,14 +32,13 @@ def chat_with_gradio(message: str, api_name: str = DEFAULT_API):
21
  except Exception as e:
22
  raise RuntimeError(f"Gradio API error: {e}")
23
 
24
-
25
  class ChatRequest(BaseModel):
26
  message: str
27
  api_name: str = DEFAULT_API
28
 
29
  app = FastAPI()
30
 
31
- @app.post("/chat")
32
  async def chat_endpoint(req: ChatRequest):
33
  """Forward chat requests to the Gradio API."""
34
  try:
@@ -37,7 +47,7 @@ async def chat_endpoint(req: ChatRequest):
37
  except RuntimeError as e:
38
  raise HTTPException(status_code=502, detail=str(e))
39
 
40
- @app.post("/v1/chat/completions")
41
  async def openai_chat_completions(request: Request):
42
  """
43
  OpenAI-compatible chat completions endpoint that forwards to Gradio.
 
1
+ from fastapi import FastAPI, HTTPException, Request, Depends, Security
2
  from fastapi.responses import StreamingResponse, JSONResponse
3
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
  from pydantic import BaseModel
5
  from gradio_client import Client
6
  import time
 
12
 
13
  client = Client(SPACE_ID)
14
 
15
+ # Security setup
16
+ security = HTTPBearer()
17
+ VALID_API_KEY = "sk-1234" # Replace with your actual API key
18
+
19
+ async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
20
+ if credentials.scheme != "Bearer":
21
+ raise HTTPException(status_code=403, detail="Invalid authentication scheme")
22
+ if credentials.credentials != VALID_API_KEY:
23
+ raise HTTPException(status_code=403, detail="Invalid API key")
24
+ return credentials.credentials
25
 
26
  def chat_with_gradio(message: str, api_name: str = DEFAULT_API):
27
  """
 
32
  except Exception as e:
33
  raise RuntimeError(f"Gradio API error: {e}")
34
 
 
35
  class ChatRequest(BaseModel):
36
  message: str
37
  api_name: str = DEFAULT_API
38
 
39
  app = FastAPI()
40
 
41
+ @app.post("/chat", dependencies=[Depends(get_api_key)])
42
  async def chat_endpoint(req: ChatRequest):
43
  """Forward chat requests to the Gradio API."""
44
  try:
 
47
  except RuntimeError as e:
48
  raise HTTPException(status_code=502, detail=str(e))
49
 
50
+ @app.post("/v1/chat/completions", dependencies=[Depends(get_api_key)])
51
  async def openai_chat_completions(request: Request):
52
  """
53
  OpenAI-compatible chat completions endpoint that forwards to Gradio.