mowan000 committed on
Commit
7e200cb
·
verified ·
1 Parent(s): fa4cba4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +379 -51
app.py CHANGED
@@ -1,17 +1,26 @@
1
- from fastapi import FastAPI, Request, Header
2
- from fastapi.responses import StreamingResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
- import httpx
 
 
 
5
  import logging
6
- import os
 
 
 
 
 
 
 
7
 
8
- # Configure logging
9
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
10
  logger = logging.getLogger(__name__)
11
 
12
  app = FastAPI()
13
 
14
- # Add CORS middleware to allow all origins
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
@@ -20,53 +29,374 @@ app.add_middleware(
20
  allow_headers=["*"],
21
  )
22
 
23
- TARGET_HOST = "https://generativelanguage.googleapis.com"
24
 
25
- # Use a persistent client for connection pooling
26
- client = httpx.AsyncClient(base_url=TARGET_HOST)
 
 
 
 
 
27
 
28
- @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
29
- async def reverse_proxy(request: Request, path: str, authorization: str = Header(None)):
30
- """
31
- A reverse proxy that forwards requests to the Google Generative Language API.
32
- It extracts the API key from the Authorization header.
33
- """
34
- if not authorization or not authorization.startswith("Bearer "):
35
- return {"error": "Authorization header with Bearer token is required."}, 401
36
-
37
- api_key = authorization.replace("Bearer ", "").strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Build the target URL
40
- url = httpx.URL(path=f"/{path}", query=request.url.query.encode("utf-8"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # Add the API key to the query parameters
43
- # httpx URL query parameters are immutable, so we create a new one
44
- params = list(url.query)
45
- params.append(("key", api_key))
46
- url = url.copy_with(query=b'&'.join([f"{k}={v}".encode() for k, v in params]))
47
-
48
- logger.info(f"Forwarding request to: {url}")
49
-
50
- # Prepare the request to be forwarded
51
- rp_req = client.build_request(
52
- method=request.method,
53
- url=url,
54
- headers=request.headers,
55
- content=await request.body(),
56
- )
57
 
58
- # Stream the response back to the client
59
- try:
60
- rp_resp = await client.send(rp_req, stream=True)
61
- except httpx.RequestError as e:
62
- logger.error(f"An error occurred while requesting {url}: {e}")
63
- return {"error": "Failed to connect to the upstream server."}, 502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- return StreamingResponse(
66
- rp_resp.aiter_raw(),
67
- status_code=rp_resp.status_code,
68
- headers=rp_resp.headers,
69
- )
70
 
71
  @app.get("/health")
72
  @app.get("/")
@@ -75,6 +405,4 @@ async def health_check():
75
  return {"status": "healthy"}
76
 
77
  if __name__ == "__main__":
78
- import uvicorn
79
- port = int(os.environ.get("PORT", 8080))
80
- uvicorn.run(app, host="0.0.0.0", port=port)
 
1
+ from fastapi import FastAPI, HTTPException, Header
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
4
+ from pydantic import BaseModel
5
+ import openai
6
+ from typing import List, Optional, Union
7
  import logging
8
+ import httpx
9
+ import uuid
10
+ import time
11
+ import json
12
+ from datetime import datetime, timezone
13
+ import requests
14
+ import uvicorn
15
+ import random
16
 
17
# Root logging configuration: timestamped, INFO-level messages for the proxy.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()
23
 
 
24
  app.add_middleware(
25
  CORSMiddleware,
26
  allow_origins=["*"],
 
29
  allow_headers=["*"],
30
  )
31
 
32
# Maximum attempts (rotating through the supplied API keys) before a request fails.
MAX_RETRIES = 3
33
 
34
class ChatRequest(BaseModel):
    """OpenAI-style /v1/chat/completions request body."""

    messages: List[dict]  # chat history; each dict carries "role" and "content"
    model: str  # target model id; routing is decided by the API key, not this field
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False  # when True the endpoint answers with SSE chunks
    # NOTE: pydantic deep-copies mutable field defaults per instance, so the
    # shared [] default is safe here (unlike a plain function default).
    tools: Optional[List[dict]] = []
    tool_choice: Optional[str] = "auto"
41
 
42
class EmbeddingRequest(BaseModel):
    """OpenAI-style embeddings request body.

    NOTE(review): no embeddings endpoint is visible in this file — presumably
    reserved for future use or defined elsewhere; confirm before removing.
    """

    input: Union[str, List[str]]  # a single text or a batch of texts
    model: str
    encoding_format: Optional[str] = "float"
46
+
47
async def verify_authorization(authorization: str = Header(None)):
    """Validate the Authorization header and return the bearer token.

    The token is the raw value after "Bearer " (callers split it further
    into a comma-separated pool of API keys).

    Raises:
        HTTPException: 401 when the header is missing or is not a Bearer token.
    """
    # Do NOT print/log the header value itself: it contains API keys.
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    # removeprefix strips only the leading "Bearer "; str.replace would delete
    # every occurrence of the substring, corrupting tokens that contain it.
    return authorization.removeprefix("Bearer ")
59
+
60
def get_openai_models(api_keys):
    """Fetch the model listing from OpenAI using one randomly chosen key.

    Returns the listing as a plain dict on success, or {"error": ...} on
    any failure (the error is also logged).
    """
    api_key = random.choice(api_keys)
    try:
        # Build the client and list models in one shot; model_dump() turns
        # the SDK object into a JSON-serializable dict.
        return openai.OpenAI(api_key=api_key).models.list().model_dump()
    except Exception as e:
        logger.error(f"Error getting models from OpenAI with key {api_key}: {e}")
        return {"error": str(e)}
69
+
70
def get_gemini_models(api_keys):
    """Fetch the model listing from the Gemini API using one randomly chosen key.

    Returns the listing converted to OpenAI /v1/models shape on success,
    or {"error": ...} on any failure (the error is also logged).
    """
    api_key = random.choice(api_keys)
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"

    try:
        # Explicit timeout: requests.get without one can block this worker
        # forever if the upstream stalls.
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_models_format(gemini_models)
        else:
            logger.error(f"Error getting models from Gemini with key {api_key}: {response.status_code} - {response.text}")
            return {"error": f"Gemini API error: {response.status_code} - {response.text}"}

    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return {"error": f"Request failed: {e}"}
87
+
88
def convert_to_openai_models_format(gemini_models):
    """Translate a Gemini /models listing into OpenAI /v1/models shape.

    Each Gemini model name like "models/gemini-pro" becomes an entry whose
    "id" is the last path segment and whose "root" keeps the full name;
    "created" is stamped with the current UTC time.
    """
    data = [
        {
            "id": entry["name"].split("/")[-1],
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),
            "owned_by": "google",
            "permission": [],
            "root": entry["name"],
            "parent": None,
        }
        for entry in gemini_models.get("models", [])
    ]
    return {"object": "list", "data": data}
104
+
105
def convert_messages_to_gemini_format(messages):
    """Map OpenAI-style chat messages onto Gemini "contents" entries.

    Any role other than "user" becomes "model". A string content becomes a
    single text part; a list content may mix plain strings, {"type": "text"}
    items and {"type": "image_url"} items — data: URLs are converted to
    inline_data blobs (mime type assumed image/jpeg), other URLs pass
    through as image_url parts. Unrecognized list items are dropped.
    """

    def _part_for(item):
        # Translate one element of a list-valued "content"; None means skip.
        if isinstance(item, str):
            return {"text": item}
        if isinstance(item, dict) and item["type"] == "text":
            return {"text": item["text"]}
        if isinstance(item, dict) and item["type"] == "image_url":
            image_url = item["image_url"]["url"]
            if image_url.startswith("data:image"):
                return {
                    "inline_data": {
                        "mime_type": "image/jpeg",
                        "data": image_url.split(",")[1],
                    }
                }
            return {"image_url": {"url": image_url}}
        return None

    gemini_messages = []
    for msg in messages:
        role = "user" if msg["role"] == "user" else "model"
        content = msg["content"]
        if isinstance(content, str):
            parts = [{"text": content}]
        elif isinstance(content, list):
            parts = [p for p in map(_part_for, content) if p is not None]
        else:
            parts = []
        gemini_messages.append({"role": role, "parts": parts})
    return gemini_messages
139
 
140
async def convert_gemini_response_to_openai(response, model, stream=False):
    """Convert a Gemini generateContent response into OpenAI chat format.

    With stream=True the input is treated as one streaming chunk and a
    "chat.completion.chunk" delta is produced (None when the chunk carries
    no candidates). Otherwise a full "chat.completion" object is returned
    with finish_reason "stop" and zeroed usage counters (Gemini token
    counts are not propagated here).
    """
    if stream:
        if not response["candidates"]:
            return None
        text = response["candidates"][0]["content"]["parts"][0]["text"]
        return {
            "id": "chatcmpl-" + str(uuid.uuid4()),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {"content": text},
                    "finish_reason": None,
                }
            ],
        }

    text = response["candidates"][0]["content"]["parts"][0]["text"]
    return {
        "id": "chatcmpl-" + str(uuid.uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
180
+
181
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    """Aggregate model listings across every comma-separated API key.

    Keys beginning with "sk-" are queried against OpenAI; all others
    against Gemini. A 500 is raised only when every key failed; partial
    failures are tolerated as long as at least one key returned models.
    """
    token = await verify_authorization(authorization)
    api_keys = [key.strip() for key in token.split(',')]

    all_models = []
    error_messages = []

    for api_key in api_keys:
        # Key prefix decides the upstream provider.
        fetch = get_openai_models if api_key.startswith("sk-") else get_gemini_models
        response = fetch([api_key])

        if "error" in response:
            error_messages.append(response["error"])
        elif isinstance(response, dict) and 'data' in response:
            all_models.extend(response['data'])
        else:
            logger.warning(f"Unexpected response format from model list API for key {api_key}: {response}")

    if error_messages and not all_models:
        raise HTTPException(status_code=500, detail=f"Errors encountered: {', '.join(error_messages)}")

    return {"data": all_models, "object": "list"}
208
+
209
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    """OpenAI-compatible chat completion endpoint with multi-key retry.

    The bearer token is a comma-separated pool of API keys. A random key is
    picked per attempt: keys starting with "sk-" go to OpenAI; all others
    are treated as Gemini keys and the request is translated to the
    generateContent API. On failure the current key is removed from the
    pool and another is tried, up to MAX_RETRIES attempts total.
    """
    token = await verify_authorization(authorization)
    api_keys = [key.strip() for key in token.split(',')]
    logger.info(f"Chat completion request - Model: {request.model}")

    retries = 0

    while retries < MAX_RETRIES:
        api_key = random.choice(api_keys)
        try:
            logger.info(f"Attempt {retries + 1} with API key: {api_key}")

            if api_key.startswith("sk-"):
                # --- OpenAI branch ---
                client = openai.OpenAI(api_key=api_key)

                if request.stream:
                    logger.info("Streaming response enabled")

                    async def generate():
                        # Relay OpenAI SSE chunks unchanged, then the [DONE]
                        # sentinel. NOTE(review): the OpenAI client call is
                        # synchronous, so it blocks the event loop while the
                        # upstream streams — confirm this is acceptable.
                        try:
                            stream_response = client.chat.completions.create(
                                model=request.model,
                                messages=request.messages,
                                temperature=request.temperature,
                                stream=True,
                            )

                            for chunk in stream_response:
                                chunk_json = chunk.model_dump_json()
                                yield f"data: {chunk_json}\n\n"
                            yield "data: [DONE]\n\n"
                        except Exception as e:
                            logger.error(f"Stream error: {str(e)}")
                            raise

                    return StreamingResponse(content=generate(), media_type="text/event-stream")

                else:
                    # Non-streaming: return the SDK response as a plain dict.
                    response = client.chat.completions.create(
                        model=request.model,
                        messages=request.messages,
                        temperature=request.temperature,
                    )
                    logger.info("Chat completion successful")
                    return response.model_dump()
            else:
                # --- Gemini branch: translate messages to generateContent ---
                gemini_messages = convert_messages_to_gemini_format(request.messages)
                payload = {
                    "contents": gemini_messages,
                    "generationConfig": {
                        "temperature": request.temperature,
                    }
                }

                if request.stream:
                    logger.info("Streaming response enabled")

                    async def generate():
                        # Streaming runs after chat_completion has already
                        # returned its StreamingResponse, so key rotation and
                        # retry counting mutate the enclosing scope via
                        # nonlocal, and errors must be reported as SSE
                        # payloads rather than HTTP status codes.
                        nonlocal api_key, retries, api_keys

                        while retries < MAX_RETRIES:
                            try:
                                async with httpx.AsyncClient() as client:
                                    stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
                                    async with client.stream("POST", stream_url, json=payload, timeout=60.0) as response:
                                        if response.status_code == 429:
                                            # Rate limited: drop this key and retry with another.
                                            logger.warning(f"Rate limit reached for key: {api_key}")
                                            retries += 1
                                            if retries >= MAX_RETRIES:
                                                yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                                break

                                            api_keys.remove(api_key)
                                            if not api_keys:
                                                yield f"data: {json.dumps({'error': 'All API keys exhausted'})}\n\n"
                                                break

                                            api_key = random.choice(api_keys)
                                            logger.info(f"Retrying with a new API key: {api_key}")
                                            continue

                                        if response.status_code != 200:
                                            # Any other upstream error: same rotate-and-retry dance.
                                            # NOTE(review): .text on an unread httpx streaming
                                            # response may raise ResponseNotRead — verify.
                                            logger.error(f"Error in streaming response with key {api_key}: {response.status_code} - {response.text}")

                                            retries += 1
                                            if retries >= MAX_RETRIES:
                                                yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                                break

                                            api_keys.remove(api_key)
                                            if not api_keys:
                                                yield f"data: {json.dumps({'error': 'All API keys exhausted'})}\n\n"
                                                break

                                            api_key = random.choice(api_keys)
                                            logger.info(f"Retrying with a new API key: {api_key}")
                                            continue

                                        # Success: re-wrap each Gemini SSE chunk as an
                                        # OpenAI chat.completion.chunk event.
                                        async for line in response.aiter_lines():
                                            if line.startswith("data: "):
                                                try:
                                                    chunk = json.loads(line[6:])
                                                    if not chunk.get("candidates"):
                                                        continue

                                                    content = chunk["candidates"][0]["content"]["parts"][0]["text"]

                                                    new_chunk = {
                                                        "id": "chatcmpl-" + str(uuid.uuid4()),
                                                        "object": "chat.completion.chunk",
                                                        "created": int(time.time()),
                                                        "model": request.model,
                                                        "choices": [
                                                            {
                                                                "index": 0,
                                                                "delta": {
                                                                    "content": content
                                                                },
                                                                "finish_reason": None,
                                                            }
                                                        ],
                                                    }
                                                    yield f"data: {json.dumps(new_chunk)}\n\n"

                                                except json.JSONDecodeError:
                                                    # Skip malformed / partial SSE payloads.
                                                    continue
                                        yield "data: [DONE]\n\n"
                                        return
                            except Exception as e:
                                # Transport-level failure: rotate key and retry.
                                logger.error(f"Stream error: {str(e)}")
                                retries += 1
                                if retries >= MAX_RETRIES:
                                    yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                    break

                                api_keys.remove(api_key)
                                if not api_keys:
                                    yield f"data: {json.dumps({'error': 'All API keys exhausted'})}\n\n"
                                    break

                                api_key = random.choice(api_keys)
                                logger.info(f"Retrying with a new API key: {api_key}")
                                continue

                    return StreamingResponse(content=generate(), media_type="text/event-stream")
                else:
                    # Non-streaming Gemini call, converted to OpenAI shape.
                    async with httpx.AsyncClient() as client:
                        non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
                        response = await client.post(non_stream_url, json=payload)

                        if response.status_code != 200:
                            logger.error(f"Error in non-streaming response with key {api_key}: {response.status_code} - {response.text}")

                            retries += 1
                            if retries >= MAX_RETRIES:
                                raise HTTPException(status_code=500, detail="Max retries reached")

                            api_keys.remove(api_key)
                            if not api_keys:
                                raise HTTPException(status_code=500, detail="All API keys exhausted")

                            api_key = random.choice(api_keys)
                            logger.info(f"Retrying with a new API key: {api_key}")
                            continue

                        gemini_response = response.json()
                        logger.info("Chat completion successful")
                        return await convert_gemini_response_to_openai(gemini_response, request.model)

        except Exception as e:
            # Catch-all for this attempt: re-raise deliberate HTTP errors,
            # otherwise rotate to a new key (or give up at the limits).
            logger.error(f"Error in chat completion: {str(e)}")
            if isinstance(e, HTTPException):
                raise e

            retries += 1
            if retries >= MAX_RETRIES:
                logger.error("Max retries reached, giving up")
                raise HTTPException(status_code=500, detail="Max retries reached")

            api_keys.remove(api_key)
            if not api_keys:
                raise HTTPException(status_code=500, detail="All API keys exhausted")

            api_key = random.choice(api_keys)
            logger.info(f"Retrying with a new API key: {api_key}")
            continue

    # Fallback if the retry loop exits without returning or raising.
    raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
399
 
 
 
 
 
 
400
 
401
  @app.get("/health")
402
  @app.get("/")
 
405
  return {"status": "healthy"}
406
 
407
if __name__ == "__main__":
    # Serve the app directly on all interfaces, port 8080, when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=8080)