Update server.py
Browse files
server.py
CHANGED
|
@@ -534,10 +534,6 @@ def QWEN(
|
|
| 534 |
|
| 535 |
|
| 536 |
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
|
| 542 |
|
| 543 |
|
|
@@ -549,7 +545,7 @@ PROVIDERS: Dict[str, Dict[str, Any]] = {
|
|
| 549 |
"3": {"__func__": FREEGPT, "models": M3},
|
| 550 |
}
|
| 551 |
|
| 552 |
-
#
|
| 553 |
PROVIDER_META: Dict[str, Dict[str, Any]] = {}
|
| 554 |
|
| 555 |
class Config:
|
|
@@ -577,7 +573,7 @@ class ChatRequest:
|
|
| 577 |
messages = payload.get("messages") or payload.get("message") or payload.get("msgs")
|
| 578 |
model = payload.get("model_name") or payload.get("model")
|
| 579 |
provider = payload.get("provider") or Config.DEFAULT_PROVIDER
|
| 580 |
-
provider = str(provider)
|
| 581 |
max_tokens = payload.get("max_tokens", Config.DEFAULT_MAX_TOKENS)
|
| 582 |
temperature = payload.get("temperature", Config.DEFAULT_TEMPERATURE)
|
| 583 |
stream = payload.get("stream", Config.STREAM)
|
|
@@ -602,31 +598,78 @@ GLOBAL_AIOHTTP: Optional[aiohttp.ClientSession] = None
|
|
| 602 |
@app.on_event("startup")
|
| 603 |
async def on_startup():
|
| 604 |
global GLOBAL_AIOHTTP, PROVIDER_META
|
| 605 |
-
logger.info("
|
| 606 |
GLOBAL_AIOHTTP = aiohttp.ClientSession()
|
| 607 |
for key, payload in PROVIDERS.items():
|
| 608 |
func = payload["__func__"]
|
| 609 |
-
|
| 610 |
"func": func,
|
| 611 |
"is_async_gen_fn": inspect.isasyncgenfunction(func),
|
| 612 |
"is_coroutine_fn": inspect.iscoroutinefunction(func),
|
| 613 |
"is_generator_fn": inspect.isgeneratorfunction(func),
|
| 614 |
-
|
| 615 |
-
"is_sync": not (inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) or inspect.isgeneratorfunction(func)),
|
| 616 |
}
|
| 617 |
-
|
| 618 |
-
logger.info("Provider metadata prepared")
|
| 619 |
|
| 620 |
|
| 621 |
@app.on_event("shutdown")
|
| 622 |
async def on_shutdown():
|
| 623 |
global GLOBAL_AIOHTTP
|
| 624 |
-
logger.info("
|
| 625 |
if GLOBAL_AIOHTTP and not GLOBAL_AIOHTTP.closed:
|
| 626 |
await GLOBAL_AIOHTTP.close()
|
| 627 |
|
| 628 |
|
| 629 |
-
async def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
provider_key: str,
|
| 631 |
messages: List[Dict],
|
| 632 |
model: str,
|
|
@@ -635,8 +678,7 @@ async def _call_provider_and_iterate(
|
|
| 635 |
timeout: float,
|
| 636 |
) -> AsyncGenerator[bytes, None]:
|
| 637 |
"""
|
| 638 |
-
|
| 639 |
-
We'll transform these bytes into SSE events higher up.
|
| 640 |
"""
|
| 641 |
if provider_key not in PROVIDER_META:
|
| 642 |
raise ValueError(f"Unknown provider '{provider_key}'")
|
|
@@ -644,15 +686,14 @@ async def _call_provider_and_iterate(
|
|
| 644 |
meta = PROVIDER_META[provider_key]
|
| 645 |
func = meta["func"]
|
| 646 |
|
| 647 |
-
|
| 648 |
-
|
| 649 |
|
| 650 |
try:
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
# async generator function
|
| 654 |
if meta["is_async_gen_fn"]:
|
| 655 |
-
agen =
|
|
|
|
| 656 |
async for item in agen:
|
| 657 |
if item is None:
|
| 658 |
continue
|
|
@@ -664,14 +705,22 @@ async def _call_provider_and_iterate(
|
|
| 664 |
yield str(item).encode("utf-8")
|
| 665 |
return
|
| 666 |
|
| 667 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
if meta["is_coroutine_fn"]:
|
| 669 |
-
|
|
|
|
| 670 |
if res is None:
|
| 671 |
return
|
| 672 |
-
#
|
| 673 |
-
if
|
| 674 |
-
for item in res:
|
| 675 |
if item is None:
|
| 676 |
continue
|
| 677 |
if isinstance(item, bytes):
|
|
@@ -681,7 +730,7 @@ async def _call_provider_and_iterate(
|
|
| 681 |
else:
|
| 682 |
yield str(item).encode("utf-8")
|
| 683 |
return
|
| 684 |
-
# sync
|
| 685 |
if inspect.isgenerator(res) or (hasattr(res, "__iter__") and not isinstance(res, (str, bytes, dict))):
|
| 686 |
for item in res:
|
| 687 |
if item is None:
|
|
@@ -702,25 +751,14 @@ async def _call_provider_and_iterate(
|
|
| 702 |
yield str(res).encode("utf-8")
|
| 703 |
return
|
| 704 |
|
| 705 |
-
#
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
)
|
| 710 |
|
|
|
|
| 711 |
if sync_res is None:
|
| 712 |
return
|
| 713 |
-
if isinstance(sync_res, (list, tuple)):
|
| 714 |
-
for item in sync_res:
|
| 715 |
-
if item is None:
|
| 716 |
-
continue
|
| 717 |
-
if isinstance(item, bytes):
|
| 718 |
-
yield item
|
| 719 |
-
elif isinstance(item, str):
|
| 720 |
-
yield item.encode("utf-8")
|
| 721 |
-
else:
|
| 722 |
-
yield str(item).encode("utf-8")
|
| 723 |
-
return
|
| 724 |
if inspect.isgenerator(sync_res) or (hasattr(sync_res, "__iter__") and not isinstance(sync_res, (str, bytes, dict))):
|
| 725 |
for item in sync_res:
|
| 726 |
if item is None:
|
|
@@ -744,13 +782,14 @@ async def _call_provider_and_iterate(
|
|
| 744 |
logger.warning(err.strip())
|
| 745 |
yield err.encode("utf-8")
|
| 746 |
except Exception as e:
|
| 747 |
-
logger.exception("
|
| 748 |
err = f"[server_error] {type(e).__name__}: {e}\n"
|
| 749 |
yield err.encode("utf-8")
|
| 750 |
|
| 751 |
|
| 752 |
@app.post("/chat")
|
| 753 |
async def chat_endpoint(request: Request):
|
|
|
|
| 754 |
try:
|
| 755 |
body_bytes = await request.body()
|
| 756 |
payload = _loads(body_bytes)
|
|
@@ -764,14 +803,9 @@ async def chat_endpoint(request: Request):
|
|
| 764 |
provider_key = req.provider
|
| 765 |
|
| 766 |
if req.stream:
|
| 767 |
-
async def
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
data: {"response":"..."}\n\n
|
| 771 |
-
After completion send a final line: [DONE]\n
|
| 772 |
-
"""
|
| 773 |
-
# iterate provider outputs (raw bytes)
|
| 774 |
-
async for raw_chunk in _call_provider_and_iterate(
|
| 775 |
provider_key=provider_key,
|
| 776 |
messages=req.messages,
|
| 777 |
model=req.model or Config.DEFAULT_MODEL,
|
|
@@ -779,36 +813,28 @@ async def chat_endpoint(request: Request):
|
|
| 779 |
stream_flag=req.stream,
|
| 780 |
timeout=Config.TIMEOUT,
|
| 781 |
):
|
| 782 |
-
# decode
|
| 783 |
-
if isinstance(raw_chunk, bytes)
|
| 784 |
-
|
| 785 |
-
else:
|
| 786 |
-
text = str(raw_chunk)
|
| 787 |
-
|
| 788 |
-
# build the JSON payload {"response": "<text>"} and serialize
|
| 789 |
payload_obj = {"response": text}
|
| 790 |
try:
|
| 791 |
json_str = _dumps(payload_obj)
|
| 792 |
except Exception:
|
| 793 |
-
# fallback
|
| 794 |
-
import json as
|
| 795 |
-
json_str =
|
| 796 |
-
|
| 797 |
-
# SSE data line + double newline (SSE event terminator)
|
| 798 |
sse_event = f"data: {json_str}\n\n"
|
| 799 |
yield sse_event.encode("utf-8")
|
| 800 |
-
|
| 801 |
-
# final termination marker exactly as you requested
|
| 802 |
-
# NOTE: sending it as a line by itself (not prefixed by 'data:')
|
| 803 |
yield ("[DONE]\n").encode("utf-8")
|
| 804 |
|
| 805 |
-
|
| 806 |
-
return StreamingResponse(sse_stream_gen(), media_type="text/event-stream")
|
| 807 |
|
| 808 |
else:
|
| 809 |
-
# non-stream: collect
|
| 810 |
collected = []
|
| 811 |
-
async for chunk in
|
| 812 |
provider_key=provider_key,
|
| 813 |
messages=req.messages,
|
| 814 |
model=req.model or Config.DEFAULT_MODEL,
|
|
@@ -816,18 +842,13 @@ async def chat_endpoint(request: Request):
|
|
| 816 |
stream_flag=req.stream,
|
| 817 |
timeout=Config.TIMEOUT,
|
| 818 |
):
|
| 819 |
-
if isinstance(chunk, bytes)
|
| 820 |
-
|
| 821 |
-
else:
|
| 822 |
-
collected.append(str(chunk))
|
| 823 |
-
full_text = "".join(collected)
|
| 824 |
-
return JSONResponse({"text": full_text})
|
| 825 |
|
| 826 |
|
| 827 |
@app.get("/model")
|
| 828 |
async def model():
|
| 829 |
-
models
|
| 830 |
-
return {"models": models}
|
| 831 |
|
| 832 |
|
| 833 |
@app.get("/health")
|
|
|
|
| 534 |
|
| 535 |
|
| 536 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
|
| 538 |
|
| 539 |
|
|
|
|
| 545 |
"3": {"__func__": FREEGPT, "models": M3},
|
| 546 |
}
|
| 547 |
|
| 548 |
+
# precomputed provider metadata for speed
|
| 549 |
PROVIDER_META: Dict[str, Dict[str, Any]] = {}
|
| 550 |
|
| 551 |
class Config:
|
|
|
|
| 573 |
messages = payload.get("messages") or payload.get("message") or payload.get("msgs")
|
| 574 |
model = payload.get("model_name") or payload.get("model")
|
| 575 |
provider = payload.get("provider") or Config.DEFAULT_PROVIDER
|
| 576 |
+
provider = str(provider)
|
| 577 |
max_tokens = payload.get("max_tokens", Config.DEFAULT_MAX_TOKENS)
|
| 578 |
temperature = payload.get("temperature", Config.DEFAULT_TEMPERATURE)
|
| 579 |
stream = payload.get("stream", Config.STREAM)
|
|
|
|
| 598 |
@app.on_event("startup")
async def on_startup():
    """
    Startup hook: open the shared aiohttp session and fill PROVIDER_META.

    For every entry in PROVIDERS the callable stored under "__func__" is
    classified once with `inspect`, so later calls can look the flags up
    instead of re-inspecting the function.
    """
    global GLOBAL_AIOHTTP, PROVIDER_META
    logger.info("startup: create global aiohttp session and analyze providers")
    GLOBAL_AIOHTTP = aiohttp.ClientSession()

    for provider_id, entry in PROVIDERS.items():
        fn = entry["__func__"]
        async_gen = inspect.isasyncgenfunction(fn)
        coroutine = inspect.iscoroutinefunction(fn)
        sync_gen = inspect.isgeneratorfunction(fn)
        PROVIDER_META[provider_id] = {
            "func": fn,
            "is_async_gen_fn": async_gen,
            "is_coroutine_fn": coroutine,
            "is_generator_fn": sync_gen,
            # plain callable: none of the three special kinds above
            "is_sync_fn": not (coroutine or async_gen or sync_gen),
        }

    # Log the flags only; the function object itself is left out of the message.
    loggable = {}
    for provider_id, meta in PROVIDER_META.items():
        loggable[provider_id] = {name: flag for name, flag in meta.items() if name != "func"}
    logger.info("provider meta ready: %s", loggable)
|
|
|
|
| 613 |
|
| 614 |
|
| 615 |
@app.on_event("shutdown")
async def on_shutdown():
    """Shutdown hook: close the shared aiohttp session if one is still open."""
    global GLOBAL_AIOHTTP
    logger.info("shutdown: close global aiohttp session")
    session = GLOBAL_AIOHTTP
    # Nothing to do when the session was never created or is already closed.
    if not session or session.closed:
        return
    await session.close()
|
| 621 |
|
| 622 |
|
| 623 |
+
async def _stream_sync_generator_in_thread(func, *args, **kwargs) -> AsyncGenerator[bytes, None]:
    """
    Run a sync generator in a worker thread and stream its items back as bytes.

    ``func(*args, **kwargs)`` is called in the loop's default executor and the
    returned iterable is consumed there, so a blocking provider never stalls
    the event loop. Items travel to the async side through a bounded
    ``asyncio.Queue``.

    Args:
        func: Callable returning an iterable (typically a generator), or None.
        *args, **kwargs: Forwarded to ``func`` unchanged.

    Yields:
        Each non-None item as bytes: ``bytes`` as-is, ``str`` UTF-8 encoded,
        anything else via ``str(item)`` and then UTF-8 encoded.

    Raises:
        Exception: whatever ``func`` or its iteration raised, re-raised here.
        An item that is itself an ``Exception`` instance is raised as well.
    """
    # Local imports: only this helper needs them.
    import concurrent.futures
    import threading

    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue(maxsize=32)
    sentinel = object()
    consumer_gone = threading.Event()

    def push(item) -> bool:
        """Hand one item to the async side; return False once nobody is listening."""
        if consumer_gone.is_set():
            return False
        # Await q.put() on the loop and block this worker thread until it has
        # completed. Scheduling q.put_nowait instead raises QueueFull inside
        # the loop when the producer outruns the consumer, which silently
        # drops the item -- or the sentinel, leaving the consumer waiting
        # forever.
        fut = asyncio.run_coroutine_threadsafe(q.put(item), loop)
        while not consumer_gone.is_set():
            try:
                fut.result(timeout=0.25)
                return True
            except concurrent.futures.TimeoutError:
                # Queue still full; re-check whether the consumer went away.
                continue
            except concurrent.futures.CancelledError:
                return False
        fut.cancel()
        return False

    def worker() -> None:
        try:
            produced = func(*args, **kwargs)
            # func may return None instead of an iterable; then there is
            # nothing to stream and only the sentinel below is sent.
            if produced is not None:
                for item in produced:
                    if not push(item):
                        return
        except Exception as exc:
            # pass the exception object forward to the async side
            push(exc)
        finally:
            # exactly one end-of-stream marker, on every exit path
            push(sentinel)

    # start worker in thread
    thread_task = loop.run_in_executor(None, worker)

    try:
        # consume from queue
        while True:
            item = await q.get()
            if item is sentinel:
                break
            if isinstance(item, Exception):
                # re-raise in async context so upstream can handle
                raise item
            if item is None:
                continue
            if isinstance(item, bytes):
                yield item
            elif isinstance(item, str):
                yield item.encode("utf-8")
            else:
                yield str(item).encode("utf-8")

        # ensure worker finished/propagated exceptions
        await asyncio.shield(thread_task)
    finally:
        # Reached on normal completion, on error, and when the consumer closes
        # this generator early; lets the worker stop instead of producing
        # into a queue nobody reads.
        consumer_gone.set()
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
async def _call_provider_and_stream(
|
| 673 |
provider_key: str,
|
| 674 |
messages: List[Dict],
|
| 675 |
model: str,
|
|
|
|
| 678 |
timeout: float,
|
| 679 |
) -> AsyncGenerator[bytes, None]:
|
| 680 |
"""
|
| 681 |
+
Core streaming logic. Yields raw bytes as soon as provider yields items.
|
|
|
|
| 682 |
"""
|
| 683 |
if provider_key not in PROVIDER_META:
|
| 684 |
raise ValueError(f"Unknown provider '{provider_key}'")
|
|
|
|
| 686 |
meta = PROVIDER_META[provider_key]
|
| 687 |
func = meta["func"]
|
| 688 |
|
| 689 |
+
# pass arguments using your original parameter names so providers stay unchanged
|
| 690 |
+
kwargs = dict(messages=messages, model=model, max_token=max_token, stream=stream_flag, timeout=timeout)
|
| 691 |
|
| 692 |
try:
|
| 693 |
+
# 1) Async generator functions -> call returns an async generator (do NOT await)
|
|
|
|
|
|
|
| 694 |
if meta["is_async_gen_fn"]:
|
| 695 |
+
agen = func(Requests, **kwargs)
|
| 696 |
+
# iterate immediately (no waiting for full result)
|
| 697 |
async for item in agen:
|
| 698 |
if item is None:
|
| 699 |
continue
|
|
|
|
| 705 |
yield str(item).encode("utf-8")
|
| 706 |
return
|
| 707 |
|
| 708 |
+
# 2) Sync generator functions -> call returns generator; iterate it in background thread
|
| 709 |
+
if meta["is_generator_fn"]:
|
| 710 |
+
# Note: call func in thread via helper which will iterate and push items to queue
|
| 711 |
+
async for item in _stream_sync_generator_in_thread(lambda *a, **k: func(Requests, **kwargs)):
|
| 712 |
+
yield item
|
| 713 |
+
return
|
| 714 |
+
|
| 715 |
+
# 3) Coroutine functions (async def) that return final result -> await it (can't stream before it completes)
|
| 716 |
if meta["is_coroutine_fn"]:
|
| 717 |
+
# await the coroutine under timeout (can't stream until it returns)
|
| 718 |
+
res = await asyncio.wait_for(func(Requests, **kwargs), timeout=timeout)
|
| 719 |
if res is None:
|
| 720 |
return
|
| 721 |
+
# if it returned an async generator (rare), iterate it
|
| 722 |
+
if inspect.isasyncgen(res):
|
| 723 |
+
async for item in res:
|
| 724 |
if item is None:
|
| 725 |
continue
|
| 726 |
if isinstance(item, bytes):
|
|
|
|
| 730 |
else:
|
| 731 |
yield str(item).encode("utf-8")
|
| 732 |
return
|
| 733 |
+
# if it returned a sync iterable -> iterate and yield
|
| 734 |
if inspect.isgenerator(res) or (hasattr(res, "__iter__") and not isinstance(res, (str, bytes, dict))):
|
| 735 |
for item in res:
|
| 736 |
if item is None:
|
|
|
|
| 751 |
yield str(res).encode("utf-8")
|
| 752 |
return
|
| 753 |
|
| 754 |
+
# 4) Sync plain function (not generator) -> run in thread (returns value or iterable)
|
| 755 |
+
# We call func in a thread and stream results as they appear if it's iterable.
|
| 756 |
+
def sync_call_wrapper():
|
| 757 |
+
return func(Requests, **kwargs)
|
|
|
|
| 758 |
|
| 759 |
+
sync_res = await asyncio.wait_for(asyncio.to_thread(sync_call_wrapper), timeout=timeout)
|
| 760 |
if sync_res is None:
|
| 761 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
if inspect.isgenerator(sync_res) or (hasattr(sync_res, "__iter__") and not isinstance(sync_res, (str, bytes, dict))):
|
| 763 |
for item in sync_res:
|
| 764 |
if item is None:
|
|
|
|
| 782 |
logger.warning(err.strip())
|
| 783 |
yield err.encode("utf-8")
|
| 784 |
except Exception as e:
|
| 785 |
+
logger.exception("provider error")
|
| 786 |
err = f"[server_error] {type(e).__name__}: {e}\n"
|
| 787 |
yield err.encode("utf-8")
|
| 788 |
|
| 789 |
|
| 790 |
@app.post("/chat")
|
| 791 |
async def chat_endpoint(request: Request):
|
| 792 |
+
# fast load
|
| 793 |
try:
|
| 794 |
body_bytes = await request.body()
|
| 795 |
payload = _loads(body_bytes)
|
|
|
|
| 803 |
provider_key = req.provider
|
| 804 |
|
| 805 |
if req.stream:
|
| 806 |
+
async def sse_stream():
|
| 807 |
+
# iterate provider stream and immediately send SSE-formatted chunks
|
| 808 |
+
async for raw_chunk in _call_provider_and_stream(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
provider_key=provider_key,
|
| 810 |
messages=req.messages,
|
| 811 |
model=req.model or Config.DEFAULT_MODEL,
|
|
|
|
| 813 |
stream_flag=req.stream,
|
| 814 |
timeout=Config.TIMEOUT,
|
| 815 |
):
|
| 816 |
+
# decode raw chunk to text
|
| 817 |
+
text = raw_chunk.decode("utf-8", errors="ignore") if isinstance(raw_chunk, (bytes, bytearray)) else str(raw_chunk)
|
| 818 |
+
# prepare JSON payload object
|
|
|
|
|
|
|
|
|
|
|
|
|
| 819 |
payload_obj = {"response": text}
|
| 820 |
try:
|
| 821 |
json_str = _dumps(payload_obj)
|
| 822 |
except Exception:
|
| 823 |
+
# fallback
|
| 824 |
+
import json as _fallback_json
|
| 825 |
+
json_str = _fallback_json.dumps(payload_obj)
|
| 826 |
+
# send SSE data line + blank line
|
|
|
|
| 827 |
sse_event = f"data: {json_str}\n\n"
|
| 828 |
yield sse_event.encode("utf-8")
|
| 829 |
+
# final termination marker exactly as requested
|
|
|
|
|
|
|
| 830 |
yield ("[DONE]\n").encode("utf-8")
|
| 831 |
|
| 832 |
+
return StreamingResponse(sse_stream(), media_type="text/event-stream")
|
|
|
|
| 833 |
|
| 834 |
else:
|
| 835 |
+
# non-stream: collect (only for non-stream requests)
|
| 836 |
collected = []
|
| 837 |
+
async for chunk in _call_provider_and_stream(
|
| 838 |
provider_key=provider_key,
|
| 839 |
messages=req.messages,
|
| 840 |
model=req.model or Config.DEFAULT_MODEL,
|
|
|
|
| 842 |
stream_flag=req.stream,
|
| 843 |
timeout=Config.TIMEOUT,
|
| 844 |
):
|
| 845 |
+
collected.append(chunk.decode("utf-8", errors="ignore") if isinstance(chunk, (bytes, bytearray)) else str(chunk))
|
| 846 |
+
return JSONResponse({"text": "".join(collected)})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
|
| 849 |
@app.get("/model")
async def model():
    """Return M1, M2 and M3 grouped in a list under the "models" key."""
    catalog = [M1, M2, M3]
    return {"models": catalog}
|
|
|
|
| 852 |
|
| 853 |
|
| 854 |
@app.get("/health")
|