Hivra committed
Commit 072a239 · verified · 1 Parent(s): d5f83a2

Update app/main.py

Files changed (1):
  1. app/main.py +26 -77
app/main.py CHANGED
@@ -1,100 +1,51 @@
-import os
-from dotenv import load_dotenv
-from fastapi import FastAPI, HTTPException, Request, Depends, Header
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import StreamingResponse, JSONResponse
 from pydantic import BaseModel
-from gradio_client import Client, utils
-import httpx
+from gradio_client import Client
 import time
 import json
 
-# Load environment variables
-load_dotenv()
+# Configure your Gradio Space ID and default endpoint
+SPACE_ID = "prithivMLmods/SAMBANOVA"
+DEFAULT_API = "/chat"
 
-# Configuration
-SPACE_ID = os.getenv("SPACE_ID", "prithivMLmods/SAMBANOVA")
-DEFAULT_API = os.getenv("DEFAULT_API", "/chat")
-GRADIO_TIMEOUT = int(os.getenv("GRADIO_TIMEOUT", "60"))
-API_KEY = os.getenv("API_KEY")
-if not API_KEY:
-    raise RuntimeError("Missing API_KEY in environment")
-
-# Lazy Gradio client initialization
-global_client = None
-
-def get_gradio_client():
-    """Initialize or return cached Gradio client, retrying on rate limits or timeouts."""
-    global global_client
-    if global_client:
-        return global_client
-    # Try up to 3 times with exponential backoff
-    for attempt in range(3):
-        try:
-            client = Client(SPACE_ID)
-            # set HTTPX timeouts (connect quick, allow longer reads)
-            client.client.timeout = httpx.Timeout(connect=5.0, read=GRADIO_TIMEOUT)
-            global_client = client
-            return client
-        except utils.TooManyRequestsError:
-            if attempt < 2:
-                time.sleep(2 ** attempt)
-                continue
-            raise RuntimeError("Gradio API config rate-limited. Please try again later.")
-        except Exception as e:
-            msg = str(e)
-            if "ReadTimeout" in msg and attempt < 2:
-                # retry on read timeouts
-                time.sleep(2 ** attempt)
-                continue
-            raise RuntimeError(f"Failed to initialize Gradio client: {e}")
-        except utils.TooManyRequestsError:
-            raise RuntimeError("Gradio API config rate-limited. Please try again later.")
-        except Exception as e:
-            raise RuntimeError(f"Failed to initialize Gradio client: {e}")
+client = Client(SPACE_ID)
 
 
 def chat_with_gradio(message: str, api_name: str = DEFAULT_API):
-    client = get_gradio_client()
+    """
+    Send a chat message to the Gradio API and return the response.
+    """
     try:
         return client.predict(message=message, api_name=api_name)
     except Exception as e:
-        msg = str(e)
-        if "ReadTimeout" in msg:
-            raise RuntimeError(f"Gradio API timed out after {GRADIO_TIMEOUT}s")
         raise RuntimeError(f"Gradio API error: {e}")
 
 
-def verify_api_key(
-    x_api_key: str = Header(None),
-    authorization: str = Header(None)
-):
-    """Accepts either X-API-Key or Authorization: Bearer <key>"""
-    token = x_api_key
-    if not token and authorization:
-        scheme, _, cred = authorization.partition(' ')
-        if scheme.lower() == 'bearer':
-            token = cred
-    if token != API_KEY:
-        raise HTTPException(status_code=401, detail="Invalid or missing API Key")
-
 class ChatRequest(BaseModel):
     message: str
     api_name: str = DEFAULT_API
 
 app = FastAPI()
 
-@app.post("/chat", dependencies=[Depends(verify_api_key)])
+@app.post("/chat")
 async def chat_endpoint(req: ChatRequest):
+    """Forward chat requests to the Gradio API."""
     try:
         reply = chat_with_gradio(req.message, req.api_name)
         return {"reply": reply}
     except RuntimeError as e:
         raise HTTPException(status_code=502, detail=str(e))
 
-@app.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)])
+@app.post("/v1/chat/completions")
 async def openai_chat_completions(request: Request):
+    """
+    OpenAI-compatible chat completions endpoint that forwards to Gradio.
+    Supports both streaming and non-streaming.
+    """
     body = await request.json()
     messages = body.get("messages")
+    model = body.get("model")
     stream = body.get("stream", False)
 
     if not messages or not isinstance(messages, list):
@@ -102,25 +53,25 @@ async def openai_chat_completions(request: Request):
 
     user_msg = messages[-1].get("content", "")
 
+    # Call Gradio
     try:
         reply = chat_with_gradio(user_msg, DEFAULT_API)
     except RuntimeError as e:
         raise HTTPException(status_code=502, detail=str(e))
 
+    # Build usage (simple token count by words)
     prompt_tokens = sum(len(m.get("content", "").split()) for m in messages)
    completion_tokens = len(str(reply).split())
-    usage = {
-        "prompt_tokens": prompt_tokens,
-        "completion_tokens": completion_tokens,
-        "total_tokens": prompt_tokens + completion_tokens
-    }
+    usage = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens}
 
     if stream:
+        # Stream word by word as OpenAI SSE
         def event_generator():
             for word in str(reply).split():
-                chunk = {"choices": [{"delta": {"content": word + " "}, "index": 0, "finish_reason": None}]}
+                chunk = {"choices": [{"delta": {"content": word + " "}, "index": 0, "finish_reason": None}]}
                 yield f"data: {json.dumps(chunk)}\n\n"
                 time.sleep(0.05)
+            # send done
             done = {"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]}
             yield f"data: {json.dumps(done)}\n\n"
         return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -129,7 +80,7 @@ async def openai_chat_completions(request: Request):
         "id": f"chatcmpl-{int(time.time())}",
         "object": "chat.completion",
         "created": int(time.time()),
-        "model": body.get("model"),
+        "model": model,
         "choices": [{"index": 0, "message": {"role": "assistant", "content": reply}, "finish_reason": "stop"}],
         "usage": usage
     }
@@ -137,7 +88,5 @@ async def openai_chat_completions(request: Request):
 
 if __name__ == "__main__":
     import uvicorn
-    print(
-        f"Starting server on http://0.0.0.0:7860 using Space {SPACE_ID}{DEFAULT_API}"
-    )
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    print(f"Starting server on http://0.0.0.0:7860 using {SPACE_ID}{DEFAULT_API} and OpenAI-compatible endpoint /v1/chat/completions")
+    uvicorn.run(app, host="0.0.0.0", port=7860)
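For reference, a minimal client sketch exercising the two routes this commit defines. Assumptions not in the diff: the server is running locally on port 7860, the `requests` package is installed, and the model name "sambanova" is an arbitrary placeholder that the endpoint simply echoes back.

import json
import requests

BASE = "http://localhost:7860"  # assumes the uvicorn server above is running locally

# Plain proxy endpoint: POST /chat with a ChatRequest body
r = requests.post(f"{BASE}/chat", json={"message": "Hello!"})
print(r.json()["reply"])

# OpenAI-compatible endpoint, non-streaming
r = requests.post(
    f"{BASE}/v1/chat/completions",
    json={"model": "sambanova", "messages": [{"role": "user", "content": "Hello!"}]},
)
print(r.json()["choices"][0]["message"]["content"])

# OpenAI-compatible endpoint, streaming: consume the SSE chunks emitted above
with requests.post(
    f"{BASE}/v1/chat/completions",
    json={"model": "sambanova", "messages": [{"role": "user", "content": "Hello!"}], "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        data = line.decode()
        if data.startswith("data: "):
            delta = json.loads(data[len("data: "):])["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)

Note that the server computes the full Gradio reply before streaming, so the SSE stream is a word-by-word replay of a finished response rather than true token streaming.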
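Since the route and response shape mimic OpenAI's chat completions, the standard openai Python SDK (v1.x) can in principle be pointed at this proxy via base_url; a hedged sketch, untested against this server, with the api_key value unused because the commit removed authentication:

from openai import OpenAI

# Point the OpenAI client at the local proxy (assumption: openai>=1.0 installed)
client = OpenAI(base_url="http://localhost:7860/v1", api_key="not-used")
resp = client.chat.completions.create(
    model="sambanova",  # placeholder; the server only echoes it back
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)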