Spaces:
Sleeping
Sleeping
Commit
·
a415c67
1
Parent(s):
ab063cf
minor integration adjustments
Browse files- middlewares.py +22 -24
- server.py +7 -9
- validators/__init__.py +2 -2
- validators/sn1_validator_wrapper.py +4 -5
middlewares.py
CHANGED
|
@@ -1,34 +1,32 @@
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
import bittensor as bt
|
| 4 |
-
from aiohttp.web import Response
|
| 5 |
|
| 6 |
EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
return middleware_handler
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
return middleware_handler
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
import bittensor as bt
|
| 4 |
+
from aiohttp.web import Request, Response, middleware
|
| 5 |
|
| 6 |
EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
|
| 7 |
|
| 8 |
+
@middleware
|
| 9 |
+
async def api_key_middleware(request: Request, handler):
|
| 10 |
+
# Logging the request
|
| 11 |
+
bt.logging.info(f"Handling {request.method} request to {request.path}")
|
| 12 |
|
| 13 |
+
# Check access key
|
| 14 |
+
access_key = request.headers.get("api_key")
|
| 15 |
+
if EXPECTED_ACCESS_KEY is not None and access_key != EXPECTED_ACCESS_KEY:
|
| 16 |
+
bt.logging.error(f'Invalid access key: {access_key}')
|
| 17 |
+
return Response(status=401, reason="Invalid access key")
|
| 18 |
|
| 19 |
+
# Continue to the next handler if the API key is valid
|
| 20 |
+
return await handler(request)
|
|
|
|
| 21 |
|
| 22 |
+
@middleware
|
| 23 |
+
async def json_parsing_middleware(request: Request, handler):
|
| 24 |
+
try:
|
| 25 |
+
# Parsing JSON data from the request
|
| 26 |
+
request['data'] = await request.json()
|
| 27 |
+
except json.JSONDecodeError as e:
|
| 28 |
+
bt.logging.error(f'Invalid JSON data: {str(e)}')
|
| 29 |
+
return Response(status=400, text="Invalid JSON")
|
| 30 |
|
| 31 |
+
# Continue to the next handler if JSON is successfully parsed
|
| 32 |
+
return await handler(request)
|
|
|
server.py
CHANGED
|
@@ -34,8 +34,6 @@ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.n
|
|
| 34 |
```
|
| 35 |
add --mock to test the echo stream
|
| 36 |
"""
|
| 37 |
-
@api_key_middleware
|
| 38 |
-
@json_parsing_middleware
|
| 39 |
async def chat(request: web.Request) -> Response:
|
| 40 |
"""
|
| 41 |
Chat endpoint for the validator.
|
|
@@ -43,7 +41,7 @@ async def chat(request: web.Request) -> Response:
|
|
| 43 |
request_data = request['data']
|
| 44 |
params = QueryValidatorParams.from_dict(request_data)
|
| 45 |
# TODO: SET STREAM AS DEFAULT
|
| 46 |
-
stream = request_data.get('stream',
|
| 47 |
|
| 48 |
# Access the validator from the application context
|
| 49 |
validator: ValidatorAPI = request.app['validator']
|
|
@@ -52,29 +50,29 @@ async def chat(request: web.Request) -> Response:
|
|
| 52 |
return response
|
| 53 |
|
| 54 |
|
| 55 |
-
@api_key_middleware
|
| 56 |
-
@json_parsing_middleware
|
| 57 |
async def echo_stream(request, request_data):
|
| 58 |
request_data = request['data']
|
| 59 |
return await utils.echo_stream(request_data)
|
| 60 |
|
| 61 |
|
|
|
|
| 62 |
class ValidatorApplication(web.Application):
|
| 63 |
def __init__(self, validator_instance=None, *args, **kwargs):
|
| 64 |
super().__init__(*args, **kwargs)
|
| 65 |
|
| 66 |
self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
|
| 67 |
|
| 68 |
-
# Add middlewares to application
|
| 69 |
-
self.middlewares.append(api_key_middleware)
|
| 70 |
-
self.middlewares.append(json_parsing_middleware)
|
| 71 |
-
|
| 72 |
self.add_routes([
|
| 73 |
web.post('/chat/', chat),
|
| 74 |
web.post('/echo/', echo_stream)
|
| 75 |
])
|
|
|
|
| 76 |
# TODO: Enable rewarding and other features
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
def main(run_aio_app=True, test=False) -> None:
|
| 80 |
loop = asyncio.get_event_loop()
|
|
|
|
| 34 |
```
|
| 35 |
add --mock to test the echo stream
|
| 36 |
"""
|
|
|
|
|
|
|
| 37 |
async def chat(request: web.Request) -> Response:
|
| 38 |
"""
|
| 39 |
Chat endpoint for the validator.
|
|
|
|
| 41 |
request_data = request['data']
|
| 42 |
params = QueryValidatorParams.from_dict(request_data)
|
| 43 |
# TODO: SET STREAM AS DEFAULT
|
| 44 |
+
stream = request_data.get('stream', True)
|
| 45 |
|
| 46 |
# Access the validator from the application context
|
| 47 |
validator: ValidatorAPI = request.app['validator']
|
|
|
|
| 50 |
return response
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
| 53 |
async def echo_stream(request, request_data):
|
| 54 |
request_data = request['data']
|
| 55 |
return await utils.echo_stream(request_data)
|
| 56 |
|
| 57 |
|
| 58 |
+
|
| 59 |
class ValidatorApplication(web.Application):
|
| 60 |
def __init__(self, validator_instance=None, *args, **kwargs):
|
| 61 |
super().__init__(*args, **kwargs)
|
| 62 |
|
| 63 |
self['validator'] = validator_instance if validator_instance else S1ValidatorAPI()
|
| 64 |
|
| 65 |
+
# Add middlewares to application
|
|
|
|
|
|
|
|
|
|
| 66 |
self.add_routes([
|
| 67 |
web.post('/chat/', chat),
|
| 68 |
web.post('/echo/', echo_stream)
|
| 69 |
])
|
| 70 |
+
self.setup_middlewares()
|
| 71 |
# TODO: Enable rewarding and other features
|
| 72 |
|
| 73 |
+
def setup_middlewares(self):
|
| 74 |
+
self.middlewares.append(json_parsing_middleware)
|
| 75 |
+
self.middlewares.append(api_key_middleware)
|
| 76 |
|
| 77 |
def main(run_aio_app=True, test=False) -> None:
|
| 78 |
loop = asyncio.get_event_loop()
|
validators/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
from base import QueryValidatorParams, ValidatorAPI, MockValidator
|
| 2 |
-
from sn1_validator_wrapper import S1ValidatorAPI
|
|
|
|
| 1 |
+
from .base import QueryValidatorParams, ValidatorAPI, MockValidator
|
| 2 |
+
from .sn1_validator_wrapper import S1ValidatorAPI
|
validators/sn1_validator_wrapper.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
import json
|
| 2 |
import utils
|
|
|
|
| 3 |
import traceback
|
| 4 |
import bittensor as bt
|
| 5 |
-
import asyncio
|
| 6 |
-
from prompting.forward import handle_response
|
| 7 |
from prompting.validator import Validator
|
| 8 |
from prompting.utils.uids import get_random_uids
|
| 9 |
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
|
| 10 |
from prompting.dendrite import DendriteResponseEvent
|
| 11 |
-
from base import QueryValidatorParams, ValidatorAPI
|
| 12 |
from aiohttp.web_response import Response, StreamResponse
|
| 13 |
from deprecated import deprecated
|
| 14 |
|
|
@@ -16,7 +15,7 @@ class S1ValidatorAPI(ValidatorAPI):
|
|
| 16 |
def __init__(self):
|
| 17 |
self.validator = Validator()
|
| 18 |
|
| 19 |
-
|
| 20 |
@deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
|
| 21 |
async def get_response(self, params:QueryValidatorParams) -> Response:
|
| 22 |
try:
|
|
@@ -37,7 +36,7 @@ class S1ValidatorAPI(ValidatorAPI):
|
|
| 37 |
|
| 38 |
bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
|
| 39 |
# Encapsulate the responses in a response event (dataclass)
|
| 40 |
-
response_event = DendriteResponseEvent(responses, uids)
|
| 41 |
|
| 42 |
# convert dict to json
|
| 43 |
response = response_event.__state_dict__()
|
|
|
|
| 1 |
import json
|
| 2 |
import utils
|
| 3 |
+
import torch
|
| 4 |
import traceback
|
| 5 |
import bittensor as bt
|
|
|
|
|
|
|
| 6 |
from prompting.validator import Validator
|
| 7 |
from prompting.utils.uids import get_random_uids
|
| 8 |
from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
|
| 9 |
from prompting.dendrite import DendriteResponseEvent
|
| 10 |
+
from .base import QueryValidatorParams, ValidatorAPI
|
| 11 |
from aiohttp.web_response import Response, StreamResponse
|
| 12 |
from deprecated import deprecated
|
| 13 |
|
|
|
|
| 15 |
def __init__(self):
|
| 16 |
self.validator = Validator()
|
| 17 |
|
| 18 |
+
|
| 19 |
@deprecated(reason="This function is deprecated. Validators use stream synapse now, use get_stream_response instead.")
|
| 20 |
async def get_response(self, params:QueryValidatorParams) -> Response:
|
| 21 |
try:
|
|
|
|
| 36 |
|
| 37 |
bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
|
| 38 |
# Encapsulate the responses in a response event (dataclass)
|
| 39 |
+
response_event = DendriteResponseEvent(responses, torch.LongTensor(uids), params.timeout)
|
| 40 |
|
| 41 |
# convert dict to json
|
| 42 |
response = response_event.__state_dict__()
|