Spaces:
Paused
Paused
| from asyncio import sleep | |
| from typing import Optional | |
| from fastapi import FastAPI | |
| from fastapi.encoders import jsonable_encoder | |
| from fastapi.websockets import WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from websockets import ConnectionClosed | |
| from accelerator import Accelerator | |
| from answerer import Answerer | |
| from mapper import Mapper | |
# Instantiate the Mapper with the named model. A failure is reported on stdout
# and otherwise swallowed, in which case the name `mapper` is never bound.
try:
    mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception as load_error:
    print(f"ERROR! cannot load Mapper model!\n{load_error}")
# Module-level singletons used by the request handlers in this file.
answerer = Answerer(
    model="RWKV-5-World-3B-v2-20231118-ctx16k",
    vocab="rwkv_vocab_v20230424",
    strategy="cpu bf16",
    ctx_limit=16 * 1024,  # 16384; lines up with the "ctx16k" in the model name
)
accelerator = Accelerator()
app = FastAPI()
# Minimal browser client returned by `index`. The page opens a WebSocket to the
# hosted /answer URL, sends the textarea content when the form is submitted and
# displays every text message it receives in the <p id="output"> element.
#
# Two fixes compared with the previous markup:
#   * the native form submission is cancelled before the connection check, so
#     the page no longer reloads (and loses the error message) when the socket
#     is not open;
#   * received text is written with textContent instead of innerHTML, so it is
#     shown literally rather than being parsed as markup.
HTML = """
<!DOCTYPE HTML>
<html>
<body>
<form action="" onsubmit="ask(event)">
  <textarea id="prompt"></textarea>
  <br>
  <input type="submit" value="SEND" />
</form>
<p id="output"></p>
<script>
const prompt = document.getElementById("prompt");
const output = document.getElementById("output");
const ws = new WebSocket("wss://daniilalpha-answerer-api.hf.space/answer");
ws.onmessage = (e) => answer(e.data);
function ask(event) {
  // Cancel the native submit first so that no code path reloads the page.
  event.preventDefault();
  if (ws.readyState !== WebSocket.OPEN) {
    answer("websocket is not connected!");
    return;
  }
  ws.send(prompt.value);
}
function answer(value) {
  // Show the message as plain text; it must not be interpreted as markup.
  output.textContent = value;
}
</script>
</body>
</html>
"""
def index():
    """Return the module-level ``HTML`` test page wrapped in an HTMLResponse."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text -- confirm how it is registered with `app`.
    page = HTMLResponse(content=HTML)
    return page
async def answer(ws: WebSocket):
    """Pass a WebSocket to the accelerator and stay alive while it is connected.

    After ``accelerator.connect(ws)`` returns, the coroutine re-checks
    ``accelerator.connected()`` every 10 seconds and finishes once it
    reports false.

    Args:
        ws: The incoming WebSocket, forwarded unchanged to the accelerator.
    """
    # NOTE(review): a second function named `answer` is defined further down in
    # this module and rebinds the name. Presumably both are registered through
    # decorators that are not visible here -- confirm.
    await accelerator.connect(ws)
    while True:
        if not accelerator.connected():
            break
        # Polling interval between connection checks.
        await sleep(10)
def map(query: Optional[str], items: Optional[list[str]]):
    """Run the Mapper on ``query`` and ``items`` and return its result as JSON.

    Args:
        query: Passed to the mapper as its first argument.
        items: Passed to the mapper as its second argument.

    Returns:
        A JSONResponse holding the JSON-encoded mapper result, or a 503
        JSONResponse with a ``detail`` message when no mapper is available.
    """
    # The module-level loader only prints an error when the Mapper cannot be
    # created and leaves `mapper` unbound. Look the name up defensively so the
    # caller gets an explicit 503 instead of an unhandled NameError.
    scorer = globals().get("mapper")
    if scorer is None:
        return JSONResponse(
            {"detail": "Mapper model is not loaded"},
            status_code=503,
        )
    scores = scorer(query, items)
    return JSONResponse(jsonable_encoder(scores))
async def handle_answerer_local(ws: WebSocket, input: str):
    """Answer ``input`` with the in-process answerer and send the result.

    The answerer's async iterator is drained completely and only the last
    item it yielded is sent to the client as a single text message.

    Args:
        ws: WebSocket that receives the reply.
        input: Text forwarded to the answerer.
    """
    # Start from an empty string so an iterator that yields nothing results in
    # an empty message instead of an UnboundLocalError on the send below.
    final_text: str = ""
    # NOTE(review): 128 is passed as the answerer's second positional argument;
    # presumably a generation length limit -- confirm against Answerer.
    # NOTE(review): intermediate items are discarded; presumably each item is
    # the cumulative text so far -- confirm.
    async for final_text in answerer(input, 128):
        pass
    await ws.send_text(final_text)
async def handle_answerer_accelerated(ws: WebSocket, input: str):
    """Answer ``input`` through the accelerator, falling back to the local path.

    Args:
        ws: WebSocket that receives the reply.
        input: Text forwarded to ``accelerator.accelerate``.
    """
    accelerated = await accelerator.accelerate(input)
    if not accelerated:
        # A falsy result is treated as "no answer": use the local answerer.
        await handle_answerer_local(ws, input)
        return
    await ws.send_text(accelerated)
async def answer(ws: WebSocket):
    """Serve one prompt over a WebSocket and then close the connection.

    Accepts the socket, reads a single text message and answers it through
    the accelerator when one is connected, otherwise through the local
    answerer. If the peer goes away in the meantime (``ConnectionClosed`` or
    ``WebSocketDisconnect``) the handler returns without closing the socket
    itself.

    Args:
        ws: The client WebSocket.
    """
    await ws.accept()
    try:
        prompt = await ws.receive_text()
        if accelerator.connected():
            handler = handle_answerer_accelerated
        else:
            handler = handle_answerer_local
        await handler(ws, prompt)
    except (ConnectionClosed, WebSocketDisconnect):
        # The client is already gone; there is nothing left to close.
        return
    else:
        await ws.close()