Spaces:
Sleeping
Sleeping
Commit
·
99fddd6
1
Parent(s):
7b41d4c
Initial commit of openai-proxy
Browse files
- Dockerfile +21 -0
- LICENSE +21 -0
- README.md +5 -9
- log.py +52 -0
- main.py +80 -0
- pyproject.toml +16 -0
- requirements.txt +6 -0
- utils.py +102 -0
- utils_test.py +62 -0
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use an official Python runtime as a parent image
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# Set the working directory in the container
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Copy the requirements file into the container at /app
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
|
| 10 |
+
# Install any needed packages specified in requirements.txt
|
| 11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 12 |
+
|
| 13 |
+
# Copy the rest of the application's code into the container at /app
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
# Back4app will set the PORT environment variable, but we expose 8000 as a default
|
| 17 |
+
EXPOSE 8000
|
| 18 |
+
|
| 19 |
+
# Run uvicorn when the container launches
|
| 20 |
+
# Use 0.0.0.0 to be accessible from outside the container
|
| 21 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 fangwentong
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: green
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
short_description: openai-proxy
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: OpenAI Proxy
|
| 3 |
+
emoji: 🚀
|
| 4 |
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 8000
|
| 8 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
log.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import databases
|
| 2 |
+
|
| 3 |
+
from sqlalchemy import create_engine, Column, Integer, String, BigInteger
|
| 4 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 5 |
+
from sqlalchemy.orm import sessionmaker
|
| 6 |
+
|
| 7 |
+
# database config
# SQLite file created in the working directory; `databases` wraps it for
# async access, while SQLAlchemy (below) owns the schema/DDL.
DATABASE_URL = 'sqlite:///./openai_log.db'
database = databases.Database(DATABASE_URL)
Base = declarative_base()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# database model
class OpenAILog(Base):
    """ORM row recording one proxied OpenAI API call (request + response)."""
    __tablename__ = 'openai_logs'

    id = Column(Integer, primary_key=True, index=True, autoincrement=True)
    request_url = Column(String)
    request_method = Column(String)
    request_time = Column(BigInteger)
    response_time = Column(BigInteger)
    status_code = Column(Integer)
    request_content = Column(String)
    response_header = Column(String)
    response_content = Column(String)

    def to_dict(self):
        """Return a plain dict with one entry per mapped column, in
        declaration order (suitable for an INSERT ``values(**...)``)."""
        column_names = (
            'id', 'request_url', 'request_method', 'request_time',
            'response_time', 'status_code', 'request_content',
            'response_header', 'response_content',
        )
        return {name: getattr(self, name) for name in column_names}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Sync engine is used once at import time to create the table if missing;
# runtime inserts go through the async `database` object instead.
engine = create_engine(DATABASE_URL)
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
async def save_log(log: OpenAILog):
    """Persist one ``OpenAILog`` row atomically via the async database."""
    insert_stmt = OpenAILog.__table__.insert().values(**log.to_dict())
    async with database.transaction():
        await database.execute(insert_stmt)
|
main.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import json
|
| 3 |
+
import time
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import httpx
|
| 7 |
+
from fastapi import FastAPI, Request, HTTPException
|
| 8 |
+
from starlette.background import BackgroundTask
|
| 9 |
+
|
| 10 |
+
from log import OpenAILog, save_log
|
| 11 |
+
from utils import PathMatchingTree, OverrideStreamResponse
|
| 12 |
+
|
| 13 |
+
# Routing table: the request path picks the upstream origin. The most
# specific matching prefix wins ("/" is the catch-all default).
proxied_hosts = PathMatchingTree({
    "/": "https://api.openai.com",
    "/backend-api/conversation": "https://chat.openai.com",
})

# FastAPI app
app = FastAPI()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def proxy_openai_api(request: Request):
    """Forward *request* to the upstream host chosen by path and stream the
    answer back unchanged, while capturing the full exchange into an
    ``OpenAILog`` row that a background task persists after the response ends.
    """
    # Drop hop-by-hop / proxy-revealing headers; forward everything else.
    headers = {k: v for k, v in request.headers.items() if
               k not in {'host', 'content-length', 'x-forwarded-for', 'x-real-ip', 'connection'}}
    url = f'{proxied_hosts.get_matching(request.url.path)}{request.url.path}'

    # Epoch milliseconds. (The original used datetime.now().microsecond — the
    # sub-second component only — and then subtracted it from time.time()
    # seconds, so both stored timings were meaningless.)
    start_time = int(time.time() * 1000)
    # client must stay open for the whole stream; closed in the generator below
    client = httpx.AsyncClient()

    # Starlette caches the body, so it is safe to read here and reuse.
    # Tolerate an empty body instead of crashing in request.json().
    body_bytes = await request.body() if request.method in {'POST', 'PUT'} else None
    request_body = json.loads(body_bytes) if body_bytes else None

    log = OpenAILog()

    async def stream_api_response():
        # Executed lazily by OverrideStreamResponse, after `response` is bound.
        try:
            st = client.stream(request.method, url, headers=headers,
                               params=request.query_params, json=request_body)
            async with st as res:
                # Relay upstream status/headers before the first body chunk.
                response.status_code = res.status_code
                response.init_headers({k: v for k, v in res.headers.items() if
                                       k not in {'content-length', 'content-encoding', 'alt-svc'}})

                content = bytearray()
                async for chunk in res.aiter_bytes():
                    yield chunk
                    content.extend(chunk)

                # gather log data
                log.request_url = url
                log.request_method = request.method
                log.request_time = start_time
                log.response_time = int(time.time() * 1000) - start_time
                log.status_code = res.status_code
                # Log bodies for PUT as well as POST (PUT was dropped before).
                log.request_content = body_bytes.decode('utf-8') if body_bytes is not None else None
                log.response_content = content.decode('utf-8')
                log.response_header = json.dumps([[k, v] for k, v in res.headers.items()])
        except httpx.RequestError as exc:
            raise HTTPException(status_code=500, detail=f'An error occurred while requesting: {exc}')
        finally:
            # Previously the client was never closed and leaked connections.
            await client.aclose()

    async def update_log():
        # Background task: total elapsed wall time (ms), then persist the row.
        log.response_time = int(time.time() * 1000) - start_time
        await save_log(log)

    response = OverrideStreamResponse(stream_api_response(), background=BackgroundTask(update_log))
    return response
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# Catch-all route: every path and the listed methods are proxied upstream.
# NOTE(review): `app.route` is the raw Starlette decorator and is deprecated in
# newer Starlette releases; `app.api_route` is the FastAPI-native equivalent —
# confirm behavior before switching, as FastAPI adds its own response handling.
@app.route('/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
async def request_handler(request: Request):
    """Dispatch any incoming request to the proxy."""
    return await proxy_openai_api(request)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == '__main__':
    # Local development entry point; production uses the Dockerfile's uvicorn
    # CMD (0.0.0.0:8000) instead of this loopback-only, auto-reloading server.
    import uvicorn

    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info", reload=True)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "openai-gateway"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = ""
|
| 5 |
+
authors = ["fangwentong <fangwentong2012@gmail.com>"]
|
| 6 |
+
readme = "README.md"
|
| 7 |
+
packages = [{ include = "openai_gateway" }]
|
| 8 |
+
|
| 9 |
+
[tool.poetry.dependencies]
|
| 10 |
+
python = "^3.10"
|
| 11 |
+
databases = "^0.7.0"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
[build-system]
|
| 15 |
+
requires = ["poetry-core"]
|
| 16 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.74.0
|
| 2 |
+
httpx==0.23.0
|
| 3 |
+
sqlalchemy~=1.4.47
|
| 4 |
+
databases~=0.7.0
|
| 5 |
+
uvicorn==0.15.0
|
| 6 |
+
aiosqlite==0.18.0
|
utils.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing
|
| 2 |
+
from functools import partial
|
| 3 |
+
|
| 4 |
+
import anyio
|
| 5 |
+
from fastapi.responses import StreamingResponse
|
| 6 |
+
from starlette.types import Send, Scope, Receive
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PathMatchingTree:
    """
    PathMatchingTree is a data structure that can be used to match a path with a value.
    It supports exact match, partial match, and wildcard match.
    For example, if the tree is built with the following config:
    {
        "/foo/bar": "value1",
        "/baz/qux": "value2",
        "/foo/*": "value3",
        "/foo/*/bar": "value4"
    }
    Then the following path will match the corresponding value:
        /foo/bar -> value1
        /baz/qux -> value2
        /foo/baz -> value3
        /foo/baz/bar -> value4
        /foo/baz/bar2 -> value3
    """

    def __init__(self, config):
        # Per-node state lives on the instance. (The original declared
        # `child = dict` — binding the *type*, not a dict — and relied on a
        # `value = None` class attribute; both are now proper instance fields.)
        self.child = {}
        self.value = None
        self._build_tree(config)

    def _build_tree(self, config):
        """Insert every (path, value) pair from *config* into the tree."""
        for path, value in config.items():
            self._add(path.split('/'), value)

    def _add(self, parts, value):
        """Walk/create child nodes along *parts*; store *value* at the leaf."""
        node = self
        for part in parts:
            if part == '':
                continue  # skip empty segments from leading or doubled '/'
            if part not in node.child:
                node.child[part] = PathMatchingTree({})
            node = node.child[part]
        node.value = value

    def get_matching(self, path):
        """Return the value of the deepest node reachable from *path*.

        At each segment an exact child wins over the '*' wildcard; when
        neither matches, the walk stops and the current node's value is
        returned (None for interior nodes with no value).
        """
        node = self
        for part in path.split('/'):
            if part == '':
                continue
            if part in node.child:
                node = node.child[part]
            elif '*' in node.child:
                node = node.child['*']
            else:
                break
        return node.value
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class OverrideStreamResponse(StreamingResponse):
    """
    Override StreamingResponse to support lazy send response status_code and response headers

    Unlike the stock StreamingResponse, the 'http.response.start' message is
    deferred until the first body chunk is produced, so the body generator may
    still mutate `self.status_code` / call `init_headers()` (as the proxy does
    after seeing the upstream response).
    """

    async def stream_response(self, send: Send) -> None:
        # Send headers just before the first chunk; if the iterator is empty,
        # headers are still sent once, followed by the terminating empty body.
        first_chunk = True
        async for chunk in self.body_iterator:
            if first_chunk:
                await self.send_request_header(send)
                first_chunk = False
            if not isinstance(chunk, bytes):
                chunk = chunk.encode(self.charset)
            await send({'type': 'http.response.body', 'body': chunk, 'more_body': True})

        if first_chunk:
            await self.send_request_header(send)
        # more_body=False terminates the ASGI response.
        await send({'type': 'http.response.body', 'body': b'', 'more_body': False})

    async def send_request_header(self, send: Send) -> None:
        # Emit the (possibly updated) status and headers as the ASGI
        # 'http.response.start' message.
        await send(
            {
                'type': 'http.response.start',
                'status': self.status_code,
                'headers': self.raw_headers,
            }
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Race the body stream against client disconnect (same pattern as
        # Starlette's StreamingResponse): whichever finishes first cancels the
        # other via the shared cancel scope.
        async with anyio.create_task_group() as task_group:
            async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
                await func()
                # NOTE(review): in anyio >= 3, CancelScope.cancel() is a plain
                # sync method returning None, so awaiting it would raise
                # TypeError — verify the pinned anyio version supports this.
                await task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        # Run the background task (the proxy's log writer) after the response.
        if self.background is not None:
            await self.background()
|
utils_test.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
from utils import PathMatchingTree
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TestPathMatchingTree(unittest.TestCase):
    """Behavioural tests for utils.PathMatchingTree path routing."""

    def test_get_matching_exact_match(self):
        tree = PathMatchingTree({"foo/bar": "value1", "baz/qux": "value2"})
        self.assertEqual(tree.get_matching("foo/bar"), "value1")

    def test_get_matching_partial_match(self):
        tree = PathMatchingTree({"foo/bar": "value1", "baz/qux": "value2"})
        # An interior node carries no value of its own.
        self.assertIsNone(tree.get_matching("foo"))

    def test_get_matching_wildcard_match(self):
        tree = PathMatchingTree({"/foo/*": "value1", "/baz/qux": "value2"})
        self.assertEqual(tree.get_matching("foo/bar"), "value1")

    def test_get_matching_multiple_wildcard_match(self):
        tree = PathMatchingTree({"/foo/*": "value1", "/foo/*/bar": "value2"})
        self.assertIsNone(tree.get_matching("/foo"))
        self.assertEqual(tree.get_matching("/foo/baz"), "value1")
        # Falls back to the wildcard value when the deeper branch mismatches.
        self.assertEqual(tree.get_matching("/foo/baz/bar2"), "value1")
        self.assertEqual(tree.get_matching("/foo/baz/bar"), "value2")

    def test_get_matching_no_match(self):
        tree = PathMatchingTree({"/foo/bar": "value1", "/baz/qux": "value2"})
        self.assertIsNone(tree.get_matching("/foo"))
        self.assertIsNone(tree.get_matching("/baz"))

    def test_get_matching_empty_string_match(self):
        tree = PathMatchingTree({"/": "value1"})
        self.assertEqual(tree.get_matching("/"), "value1")
        self.assertEqual(tree.get_matching("/test"), "value1")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
    # Allow running this file directly: `python utils_test.py`.
    unittest.main()
|