igor04091968 commited on
Commit
99fddd6
·
1 Parent(s): 7b41d4c

Initial commit of openai-proxy

Browse files
Files changed (9) hide show
  1. Dockerfile +21 -0
  2. LICENSE +21 -0
  3. README.md +5 -9
  4. log.py +52 -0
  5. main.py +80 -0
  6. pyproject.toml +16 -0
  7. requirements.txt +6 -0
  8. utils.py +102 -0
  9. utils_test.py +62 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.11-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /app
6
+
7
+ # Copy the requirements file into the container at /app
8
+ COPY requirements.txt .
9
+
10
+ # Install any needed packages specified in requirements.txt
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Copy the rest of the application's code into the container at /app
14
+ COPY . .
15
+
16
+ # Back4app will set the PORT environment variable, but we expose 8000 as a default
17
+ EXPOSE 8000
18
+
19
+ # Run uvicorn when the container launches
20
+ # Use 0.0.0.0 to be accessible from outside the container
21
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 fangwentong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,8 @@
1
  ---
2
- title: Openai Proxy
3
- emoji: 🌖
4
  colorFrom: green
5
- colorTo: green
6
  sdk: docker
7
- pinned: false
8
- license: mit
9
- short_description: openai-proxy
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: OpenAI Proxy
3
+ emoji: 🚀
4
  colorFrom: green
5
+ colorTo: blue
6
  sdk: docker
7
+ app_port: 8000
8
+ ---
 
 
 
 
log.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import databases
2
+
3
+ from sqlalchemy import create_engine, Column, Integer, String, BigInteger
4
+ from sqlalchemy.ext.declarative import declarative_base
5
+ from sqlalchemy.orm import sessionmaker
6
+
7
+ # database config
8
+ DATABASE_URL = 'sqlite:///./openai_log.db'
9
+ database = databases.Database(DATABASE_URL)
10
+ Base = declarative_base()
11
+
12
+
13
+ # database model
14
+ class OpenAILog(Base):
15
+ """
16
+ OpenAI API call log
17
+ """
18
+ __tablename__ = 'openai_logs'
19
+
20
+ id = Column(Integer, primary_key=True, index=True, autoincrement=True)
21
+ request_url = Column(String)
22
+ request_method = Column(String)
23
+ request_time = Column(BigInteger)
24
+ response_time = Column(BigInteger)
25
+ status_code = Column(Integer)
26
+ request_content = Column(String)
27
+ response_header = Column(String)
28
+ response_content = Column(String)
29
+
30
+ def to_dict(self):
31
+ return {
32
+ 'id': self.id,
33
+ 'request_url': self.request_url,
34
+ 'request_method': self.request_method,
35
+ 'request_time': self.request_time,
36
+ 'response_time': self.response_time,
37
+ 'status_code': self.status_code,
38
+ 'request_content': self.request_content,
39
+ 'response_header': self.response_header,
40
+ 'response_content': self.response_content,
41
+ }
42
+
43
+
44
+ engine = create_engine(DATABASE_URL)
45
+ Base.metadata.create_all(bind=engine)
46
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
47
+
48
+
49
+ async def save_log(log: OpenAILog):
50
+ async with database.transaction():
51
+ query = OpenAILog.__table__.insert().values(**log.to_dict())
52
+ await database.execute(query)
main.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import json
3
+ import time
4
+ from datetime import datetime
5
+
6
+ import httpx
7
+ from fastapi import FastAPI, Request, HTTPException
8
+ from starlette.background import BackgroundTask
9
+
10
+ from log import OpenAILog, save_log
11
+ from utils import PathMatchingTree, OverrideStreamResponse
12
+
13
+ proxied_hosts = PathMatchingTree({
14
+ "/": "https://api.openai.com",
15
+ "/backend-api/conversation": "https://chat.openai.com",
16
+ })
17
+
18
+ # FastAPI app
19
+ app = FastAPI()
20
+
21
+
22
+ async def proxy_openai_api(request: Request):
23
+ # proxy request to OpenAI API
24
+ headers = {k: v for k, v in request.headers.items() if
25
+ k not in {'host', 'content-length', 'x-forwarded-for', 'x-real-ip', 'connection'}}
26
+ url = f'{proxied_hosts.get_matching(request.url.path)}{request.url.path}'
27
+
28
+ start_time = datetime.now().microsecond
29
+ # create httpx async client
30
+ client = httpx.AsyncClient()
31
+
32
+ request_body = await request.json() if request.method in {'POST', 'PUT'} else None
33
+
34
+ log = OpenAILog()
35
+
36
+ async def stream_api_response():
37
+ nonlocal log
38
+ try:
39
+ st = client.stream(request.method, url, headers=headers, params=request.query_params, json=request_body)
40
+ async with st as res:
41
+ response.status_code = res.status_code
42
+ response.init_headers({k: v for k, v in res.headers.items() if
43
+ k not in {'content-length', 'content-encoding', 'alt-svc'}})
44
+
45
+ content = bytearray()
46
+ async for chunk in res.aiter_bytes():
47
+ yield chunk
48
+ content.extend(chunk)
49
+
50
+ # gather log data
51
+ log.request_url = url
52
+ log.request_method = request.method
53
+ log.request_time = start_time
54
+ log.response_time = time.time() - start_time
55
+ log.status_code = res.status_code
56
+ log.request_content = (await request.body()).decode('utf-8') if request.method == 'POST' else None
57
+ log.response_content = content.decode('utf-8')
58
+ log.response_header = json.dumps([[k, v] for k, v in res.headers.items()])
59
+
60
+ except httpx.RequestError as exc:
61
+ raise HTTPException(status_code=500, detail=f'An error occurred while requesting: {exc}')
62
+
63
+ async def update_log():
64
+ nonlocal log
65
+ log.response_time = datetime.now().microsecond - start_time
66
+ await save_log(log)
67
+
68
+ response = OverrideStreamResponse(stream_api_response(), background=BackgroundTask(update_log))
69
+ return response
70
+
71
+
72
+ @app.route('/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
73
+ async def request_handler(request: Request):
74
+ return await proxy_openai_api(request)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ import uvicorn
79
+
80
+ uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info", reload=True)
pyproject.toml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "openai-gateway"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["fangwentong <fangwentong2012@gmail.com>"]
6
+ readme = "README.md"
7
+ packages = [{ include = "openai_gateway" }]
8
+
9
+ [tool.poetry.dependencies]
10
+ python = "^3.10"
11
+ databases = "^0.7.0"
12
+
13
+
14
+ [build-system]
15
+ requires = ["poetry-core"]
16
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi==0.74.0
2
+ httpx==0.23.0
3
+ sqlalchemy~=1.4.47
4
+ databases~=0.7.0
5
+ uvicorn==0.15.0
6
+ aiosqlite==0.18.0
utils.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from functools import partial
3
+
4
+ import anyio
5
+ from fastapi.responses import StreamingResponse
6
+ from starlette.types import Send, Scope, Receive
7
+
8
+
9
+ class PathMatchingTree:
10
+ """
11
+ PathMatchingTree is a data structure that can be used to match a path with a value.
12
+ It supports exact match, partial match, and wildcard match.
13
+ For example, if the tree is built with the following config:
14
+ {
15
+ "/foo/bar": "value1",
16
+ "/baz/qux": "value2",
17
+ "/foo/*": "value3",
18
+ "/foo/*/bar": "value4"
19
+ }
20
+ Then the following path will match the corresponding value:
21
+ /foo/bar -> value1
22
+ /baz/qux -> value2
23
+ /foo/baz -> value3
24
+ /foo/baz/bar -> value4
25
+ /foo/baz/bar2 -> value3
26
+ """
27
+ child = dict
28
+ value = None
29
+
30
+ def __init__(self, config):
31
+ self.child = {}
32
+ self._build_tree(config)
33
+
34
+ def _build_tree(self, config):
35
+ for k, v in config.items():
36
+ parts = k.split('/')
37
+ self._add(parts, v)
38
+
39
+ def _add(self, parts, value):
40
+ node = self
41
+ for part in parts:
42
+ if part == '':
43
+ continue
44
+ if part not in node.child:
45
+ node.child[part] = PathMatchingTree(dict())
46
+ node = node.child[part]
47
+ node.value = value
48
+
49
+ def get_matching(self, path):
50
+ parts = path.split('/')
51
+ matched = self
52
+ for part in parts:
53
+ if part == '':
54
+ continue
55
+ if part in matched.child:
56
+ matched = matched.child[part]
57
+ elif '*' in matched.child:
58
+ matched = matched.child['*']
59
+ else:
60
+ break
61
+ return matched.value
62
+
63
+
64
+ class OverrideStreamResponse(StreamingResponse):
65
+ """
66
+ Override StreamingResponse to support lazy send response status_code and response headers
67
+ """
68
+
69
+ async def stream_response(self, send: Send) -> None:
70
+ first_chunk = True
71
+ async for chunk in self.body_iterator:
72
+ if first_chunk:
73
+ await self.send_request_header(send)
74
+ first_chunk = False
75
+ if not isinstance(chunk, bytes):
76
+ chunk = chunk.encode(self.charset)
77
+ await send({'type': 'http.response.body', 'body': chunk, 'more_body': True})
78
+
79
+ if first_chunk:
80
+ await self.send_request_header(send)
81
+ await send({'type': 'http.response.body', 'body': b'', 'more_body': False})
82
+
83
+ async def send_request_header(self, send: Send) -> None:
84
+ await send(
85
+ {
86
+ 'type': 'http.response.start',
87
+ 'status': self.status_code,
88
+ 'headers': self.raw_headers,
89
+ }
90
+ )
91
+
92
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
93
+ async with anyio.create_task_group() as task_group:
94
+ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
95
+ await func()
96
+ await task_group.cancel_scope.cancel()
97
+
98
+ task_group.start_soon(wrap, partial(self.stream_response, send))
99
+ await wrap(partial(self.listen_for_disconnect, receive))
100
+
101
+ if self.background is not None:
102
+ await self.background()
utils_test.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from utils import PathMatchingTree
3
+
4
+
5
+ class TestPathMatchingTree(unittest.TestCase):
6
+
7
+ def test_get_matching_exact_match(self):
8
+ config = {
9
+ "foo/bar": "value1",
10
+ "baz/qux": "value2"
11
+ }
12
+ pmt = PathMatchingTree(config)
13
+ result = pmt.get_matching("foo/bar")
14
+ self.assertEqual(result, "value1")
15
+
16
+ def test_get_matching_partial_match(self):
17
+ config = {
18
+ "foo/bar": "value1",
19
+ "baz/qux": "value2"
20
+ }
21
+ pmt = PathMatchingTree(config)
22
+ self.assertIsNone(pmt.get_matching("foo"))
23
+
24
+ def test_get_matching_wildcard_match(self):
25
+ config = {
26
+ "/foo/*": "value1",
27
+ "/baz/qux": "value2"
28
+ }
29
+ pmt = PathMatchingTree(config)
30
+ self.assertEqual(pmt.get_matching("foo/bar"), "value1")
31
+
32
+ def test_get_matching_multiple_wildcard_match(self):
33
+ config = {
34
+ "/foo/*": "value1",
35
+ "/foo/*/bar": "value2"
36
+ }
37
+ pmt = PathMatchingTree(config)
38
+ self.assertIsNone(pmt.get_matching("/foo"))
39
+ self.assertEqual(pmt.get_matching("/foo/baz"), "value1")
40
+ self.assertEqual(pmt.get_matching("/foo/baz/bar2"), "value1")
41
+ self.assertEqual(pmt.get_matching("/foo/baz/bar"), "value2")
42
+
43
+ def test_get_matching_no_match(self):
44
+ config = {
45
+ "/foo/bar": "value1",
46
+ "/baz/qux": "value2"
47
+ }
48
+ pmt = PathMatchingTree(config)
49
+ self.assertIsNone(pmt.get_matching("/foo"))
50
+ self.assertIsNone(pmt.get_matching("/baz"))
51
+
52
+ def test_get_matching_empty_string_match(self):
53
+ config = {
54
+ "/": "value1"
55
+ }
56
+ pmt = PathMatchingTree(config)
57
+ self.assertEqual(pmt.get_matching("/"), "value1")
58
+ self.assertEqual(pmt.get_matching("/test"), "value1")
59
+
60
+
61
+ if __name__ == "__main__":
62
+ unittest.main()