AdarshJi committed on
Commit
3ef5e4f
·
verified ·
1 Parent(s): b0316b9

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +101 -80
server.py CHANGED
@@ -534,10 +534,6 @@ def QWEN(
534
 
535
 
536
 
537
-
538
-
539
-
540
-
541
 
542
 
543
 
@@ -549,7 +545,7 @@ PROVIDERS: Dict[str, Dict[str, Any]] = {
549
  "3": {"__func__": FREEGPT, "models": M3},
550
  }
551
 
552
- # will be filled on startup to avoid per-request introspection
553
  PROVIDER_META: Dict[str, Dict[str, Any]] = {}
554
 
555
  class Config:
@@ -577,7 +573,7 @@ class ChatRequest:
577
  messages = payload.get("messages") or payload.get("message") or payload.get("msgs")
578
  model = payload.get("model_name") or payload.get("model")
579
  provider = payload.get("provider") or Config.DEFAULT_PROVIDER
580
- provider = str(provider) # keep "1","2","3" style
581
  max_tokens = payload.get("max_tokens", Config.DEFAULT_MAX_TOKENS)
582
  temperature = payload.get("temperature", Config.DEFAULT_TEMPERATURE)
583
  stream = payload.get("stream", Config.STREAM)
@@ -602,31 +598,78 @@ GLOBAL_AIOHTTP: Optional[aiohttp.ClientSession] = None
602
  @app.on_event("startup")
603
  async def on_startup():
604
  global GLOBAL_AIOHTTP, PROVIDER_META
605
- logger.info("Starting up - creating global aiohttp session and analyzing providers")
606
  GLOBAL_AIOHTTP = aiohttp.ClientSession()
607
  for key, payload in PROVIDERS.items():
608
  func = payload["__func__"]
609
- meta = {
610
  "func": func,
611
  "is_async_gen_fn": inspect.isasyncgenfunction(func),
612
  "is_coroutine_fn": inspect.iscoroutinefunction(func),
613
  "is_generator_fn": inspect.isgeneratorfunction(func),
614
- # mark as sync if not coroutine/asyncgen/generator
615
- "is_sync": not (inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) or inspect.isgeneratorfunction(func)),
616
  }
617
- PROVIDER_META[key] = meta
618
- logger.info("Provider metadata prepared")
619
 
620
 
621
  @app.on_event("shutdown")
622
  async def on_shutdown():
623
  global GLOBAL_AIOHTTP
624
- logger.info("Shutting down - closing global aiohttp session")
625
  if GLOBAL_AIOHTTP and not GLOBAL_AIOHTTP.closed:
626
  await GLOBAL_AIOHTTP.close()
627
 
628
 
629
- async def _call_provider_and_iterate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  provider_key: str,
631
  messages: List[Dict],
632
  model: str,
@@ -635,8 +678,7 @@ async def _call_provider_and_iterate(
635
  timeout: float,
636
  ) -> AsyncGenerator[bytes, None]:
637
  """
638
- Invoke provider according to metadata and yield raw bytes.
639
- We'll transform these bytes into SSE events higher up.
640
  """
641
  if provider_key not in PROVIDER_META:
642
  raise ValueError(f"Unknown provider '{provider_key}'")
@@ -644,15 +686,14 @@ async def _call_provider_and_iterate(
644
  meta = PROVIDER_META[provider_key]
645
  func = meta["func"]
646
 
647
- async def _invoke_async():
648
- return func(Requests, Message=messages, Model=model, max_token=max_token, stream=stream_flag, timeout=timeout)
649
 
650
  try:
651
- provider_task = _invoke_async()
652
-
653
- # async generator function
654
  if meta["is_async_gen_fn"]:
655
- agen = await asyncio.wait_for(provider_task, timeout=timeout)
 
656
  async for item in agen:
657
  if item is None:
658
  continue
@@ -664,14 +705,22 @@ async def _call_provider_and_iterate(
664
  yield str(item).encode("utf-8")
665
  return
666
 
667
- # coroutine function
 
 
 
 
 
 
 
668
  if meta["is_coroutine_fn"]:
669
- res = await asyncio.wait_for(provider_task, timeout=timeout)
 
670
  if res is None:
671
  return
672
- # list/tuple
673
- if isinstance(res, (list, tuple)):
674
- for item in res:
675
  if item is None:
676
  continue
677
  if isinstance(item, bytes):
@@ -681,7 +730,7 @@ async def _call_provider_and_iterate(
681
  else:
682
  yield str(item).encode("utf-8")
683
  return
684
- # sync-iterable
685
  if inspect.isgenerator(res) or (hasattr(res, "__iter__") and not isinstance(res, (str, bytes, dict))):
686
  for item in res:
687
  if item is None:
@@ -702,25 +751,14 @@ async def _call_provider_and_iterate(
702
  yield str(res).encode("utf-8")
703
  return
704
 
705
- # sync function/generator: run in thread
706
- sync_res = await asyncio.wait_for(
707
- asyncio.to_thread(func, Requests, messages, model, max_token, stream_flag, timeout),
708
- timeout=timeout,
709
- )
710
 
 
711
  if sync_res is None:
712
  return
713
- if isinstance(sync_res, (list, tuple)):
714
- for item in sync_res:
715
- if item is None:
716
- continue
717
- if isinstance(item, bytes):
718
- yield item
719
- elif isinstance(item, str):
720
- yield item.encode("utf-8")
721
- else:
722
- yield str(item).encode("utf-8")
723
- return
724
  if inspect.isgenerator(sync_res) or (hasattr(sync_res, "__iter__") and not isinstance(sync_res, (str, bytes, dict))):
725
  for item in sync_res:
726
  if item is None:
@@ -744,13 +782,14 @@ async def _call_provider_and_iterate(
744
  logger.warning(err.strip())
745
  yield err.encode("utf-8")
746
  except Exception as e:
747
- logger.exception("Provider error")
748
  err = f"[server_error] {type(e).__name__}: {e}\n"
749
  yield err.encode("utf-8")
750
 
751
 
752
  @app.post("/chat")
753
  async def chat_endpoint(request: Request):
 
754
  try:
755
  body_bytes = await request.body()
756
  payload = _loads(body_bytes)
@@ -764,14 +803,9 @@ async def chat_endpoint(request: Request):
764
  provider_key = req.provider
765
 
766
  if req.stream:
767
- async def sse_stream_gen():
768
- """
769
- For every chunk from provider, send an SSE event line:
770
- data: {"response":"..."}\n\n
771
- After completion send a final line: [DONE]\n
772
- """
773
- # iterate provider outputs (raw bytes)
774
- async for raw_chunk in _call_provider_and_iterate(
775
  provider_key=provider_key,
776
  messages=req.messages,
777
  model=req.model or Config.DEFAULT_MODEL,
@@ -779,36 +813,28 @@ async def chat_endpoint(request: Request):
779
  stream_flag=req.stream,
780
  timeout=Config.TIMEOUT,
781
  ):
782
- # decode provider chunk to text
783
- if isinstance(raw_chunk, bytes):
784
- text = raw_chunk.decode("utf-8", errors="ignore")
785
- else:
786
- text = str(raw_chunk)
787
-
788
- # build the JSON payload {"response": "<text>"} and serialize
789
  payload_obj = {"response": text}
790
  try:
791
  json_str = _dumps(payload_obj)
792
  except Exception:
793
- # fallback to manual safe-escape for string-only payload
794
- import json as _json_fallback
795
- json_str = _json_fallback.dumps(payload_obj)
796
-
797
- # SSE data line + double newline (SSE event terminator)
798
  sse_event = f"data: {json_str}\n\n"
799
  yield sse_event.encode("utf-8")
800
-
801
- # final termination marker exactly as you requested
802
- # NOTE: sending it as a line by itself (not prefixed by 'data:')
803
  yield ("[DONE]\n").encode("utf-8")
804
 
805
- # content-type text/event-stream (SSE)
806
- return StreamingResponse(sse_stream_gen(), media_type="text/event-stream")
807
 
808
  else:
809
- # non-stream: collect and return JSON (same as before)
810
  collected = []
811
- async for chunk in _call_provider_and_iterate(
812
  provider_key=provider_key,
813
  messages=req.messages,
814
  model=req.model or Config.DEFAULT_MODEL,
@@ -816,18 +842,13 @@ async def chat_endpoint(request: Request):
816
  stream_flag=req.stream,
817
  timeout=Config.TIMEOUT,
818
  ):
819
- if isinstance(chunk, bytes):
820
- collected.append(chunk.decode("utf-8", errors="ignore"))
821
- else:
822
- collected.append(str(chunk))
823
- full_text = "".join(collected)
824
- return JSONResponse({"text": full_text})
825
 
826
 
827
  @app.get("/model")
828
  async def model():
829
- models = [M1, M2, M3]
830
- return {"models": models}
831
 
832
 
833
  @app.get("/health")
 
534
 
535
 
536
 
 
 
 
 
537
 
538
 
539
 
 
545
  "3": {"__func__": FREEGPT, "models": M3},
546
  }
547
 
548
+ # precomputed provider metadata for speed
549
  PROVIDER_META: Dict[str, Dict[str, Any]] = {}
550
 
551
  class Config:
 
573
  messages = payload.get("messages") or payload.get("message") or payload.get("msgs")
574
  model = payload.get("model_name") or payload.get("model")
575
  provider = payload.get("provider") or Config.DEFAULT_PROVIDER
576
+ provider = str(provider)
577
  max_tokens = payload.get("max_tokens", Config.DEFAULT_MAX_TOKENS)
578
  temperature = payload.get("temperature", Config.DEFAULT_TEMPERATURE)
579
  stream = payload.get("stream", Config.STREAM)
 
598
  @app.on_event("startup")
599
  async def on_startup():
600
  global GLOBAL_AIOHTTP, PROVIDER_META
601
+ logger.info("startup: create global aiohttp session and analyze providers")
602
  GLOBAL_AIOHTTP = aiohttp.ClientSession()
603
  for key, payload in PROVIDERS.items():
604
  func = payload["__func__"]
605
+ PROVIDER_META[key] = {
606
  "func": func,
607
  "is_async_gen_fn": inspect.isasyncgenfunction(func),
608
  "is_coroutine_fn": inspect.iscoroutinefunction(func),
609
  "is_generator_fn": inspect.isgeneratorfunction(func),
610
+ "is_sync_fn": not (inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) or inspect.isgeneratorfunction(func)),
 
611
  }
612
+ logger.info("provider meta ready: %s", {k: {kk: vv for kk, vv in v.items() if kk != "func"} for k, v in PROVIDER_META.items()})
 
613
 
614
 
615
  @app.on_event("shutdown")
616
  async def on_shutdown():
617
  global GLOBAL_AIOHTTP
618
+ logger.info("shutdown: close global aiohttp session")
619
  if GLOBAL_AIOHTTP and not GLOBAL_AIOHTTP.closed:
620
  await GLOBAL_AIOHTTP.close()
621
 
622
 
623
+ async def _stream_sync_generator_in_thread(func, *args, **kwargs) -> AsyncGenerator[bytes, None]:
624
+ """
625
+ Run a sync generator in a thread and stream items back via an asyncio.Queue.
626
+ This allows streaming without blocking the event loop.
627
+ """
628
+ loop = asyncio.get_running_loop()
629
+ q: asyncio.Queue = asyncio.Queue(maxsize=32)
630
+ sentinel = object()
631
+
632
+ def worker():
633
+ try:
634
+ gen = func(*args, **kwargs)
635
+ # if the function is not actually a generator but returns a value, handle that
636
+ if gen is None:
637
+ loop.call_soon_threadsafe(q.put_nowait, sentinel)
638
+ return
639
+ # If it's iterable, iterate and put items into queue
640
+ for item in gen:
641
+ loop.call_soon_threadsafe(q.put_nowait, item)
642
+ except Exception as e:
643
+ # pass the exception object forward to the async side
644
+ loop.call_soon_threadsafe(q.put_nowait, e)
645
+ finally:
646
+ loop.call_soon_threadsafe(q.put_nowait, sentinel)
647
+
648
+ # start worker in thread
649
+ thread_task = loop.run_in_executor(None, worker)
650
+
651
+ # consume from queue
652
+ while True:
653
+ item = await q.get()
654
+ if item is sentinel:
655
+ break
656
+ if isinstance(item, Exception):
657
+ # re-raise in async context so upstream can handle
658
+ raise item
659
+ if item is None:
660
+ continue
661
+ if isinstance(item, bytes):
662
+ yield item
663
+ elif isinstance(item, str):
664
+ yield item.encode("utf-8")
665
+ else:
666
+ yield str(item).encode("utf-8")
667
+
668
+ # ensure worker finished/propagated exceptions
669
+ await asyncio.shield(thread_task)
670
+
671
+
672
+ async def _call_provider_and_stream(
673
  provider_key: str,
674
  messages: List[Dict],
675
  model: str,
 
678
  timeout: float,
679
  ) -> AsyncGenerator[bytes, None]:
680
  """
681
+ Core streaming logic. Yields raw bytes as soon as provider yields items.
 
682
  """
683
  if provider_key not in PROVIDER_META:
684
  raise ValueError(f"Unknown provider '{provider_key}'")
 
686
  meta = PROVIDER_META[provider_key]
687
  func = meta["func"]
688
 
689
+ # pass arguments using your original parameter names so providers stay unchanged
690
+ kwargs = dict(messages=messages, model=model, max_token=max_token, stream=stream_flag, timeout=timeout)
691
 
692
  try:
693
+ # 1) Async generator functions -> call returns an async generator (do NOT await)
 
 
694
  if meta["is_async_gen_fn"]:
695
+ agen = func(Requests, **kwargs)
696
+ # iterate immediately (no waiting for full result)
697
  async for item in agen:
698
  if item is None:
699
  continue
 
705
  yield str(item).encode("utf-8")
706
  return
707
 
708
+ # 2) Sync generator functions -> call returns generator; iterate it in background thread
709
+ if meta["is_generator_fn"]:
710
+ # Note: call func in thread via helper which will iterate and push items to queue
711
+ async for item in _stream_sync_generator_in_thread(lambda *a, **k: func(Requests, **kwargs)):
712
+ yield item
713
+ return
714
+
715
+ # 3) Coroutine functions (async def) that return final result -> await it (can't stream before it completes)
716
  if meta["is_coroutine_fn"]:
717
+ # await the coroutine under timeout (can't stream until it returns)
718
+ res = await asyncio.wait_for(func(Requests, **kwargs), timeout=timeout)
719
  if res is None:
720
  return
721
+ # if it returned an async generator (rare), iterate it
722
+ if inspect.isasyncgen(res):
723
+ async for item in res:
724
  if item is None:
725
  continue
726
  if isinstance(item, bytes):
 
730
  else:
731
  yield str(item).encode("utf-8")
732
  return
733
+ # if it returned a sync iterable -> iterate and yield
734
  if inspect.isgenerator(res) or (hasattr(res, "__iter__") and not isinstance(res, (str, bytes, dict))):
735
  for item in res:
736
  if item is None:
 
751
  yield str(res).encode("utf-8")
752
  return
753
 
754
+ # 4) Sync plain function (not generator) -> run in thread (returns value or iterable)
755
+ # We call func in a thread and stream results as they appear if it's iterable.
756
+ def sync_call_wrapper():
757
+ return func(Requests, **kwargs)
 
758
 
759
+ sync_res = await asyncio.wait_for(asyncio.to_thread(sync_call_wrapper), timeout=timeout)
760
  if sync_res is None:
761
  return
 
 
 
 
 
 
 
 
 
 
 
762
  if inspect.isgenerator(sync_res) or (hasattr(sync_res, "__iter__") and not isinstance(sync_res, (str, bytes, dict))):
763
  for item in sync_res:
764
  if item is None:
 
782
  logger.warning(err.strip())
783
  yield err.encode("utf-8")
784
  except Exception as e:
785
+ logger.exception("provider error")
786
  err = f"[server_error] {type(e).__name__}: {e}\n"
787
  yield err.encode("utf-8")
788
 
789
 
790
  @app.post("/chat")
791
  async def chat_endpoint(request: Request):
792
+ # fast load
793
  try:
794
  body_bytes = await request.body()
795
  payload = _loads(body_bytes)
 
803
  provider_key = req.provider
804
 
805
  if req.stream:
806
+ async def sse_stream():
807
+ # iterate provider stream and immediately send SSE-formatted chunks
808
+ async for raw_chunk in _call_provider_and_stream(
 
 
 
 
 
809
  provider_key=provider_key,
810
  messages=req.messages,
811
  model=req.model or Config.DEFAULT_MODEL,
 
813
  stream_flag=req.stream,
814
  timeout=Config.TIMEOUT,
815
  ):
816
+ # decode raw chunk to text
817
+ text = raw_chunk.decode("utf-8", errors="ignore") if isinstance(raw_chunk, (bytes, bytearray)) else str(raw_chunk)
818
+ # prepare JSON payload object
 
 
 
 
819
  payload_obj = {"response": text}
820
  try:
821
  json_str = _dumps(payload_obj)
822
  except Exception:
823
+ # fallback
824
+ import json as _fallback_json
825
+ json_str = _fallback_json.dumps(payload_obj)
826
+ # send SSE data line + blank line
 
827
  sse_event = f"data: {json_str}\n\n"
828
  yield sse_event.encode("utf-8")
829
+ # final termination marker exactly as requested
 
 
830
  yield ("[DONE]\n").encode("utf-8")
831
 
832
+ return StreamingResponse(sse_stream(), media_type="text/event-stream")
 
833
 
834
  else:
835
+ # non-stream: collect (only for non-stream requests)
836
  collected = []
837
+ async for chunk in _call_provider_and_stream(
838
  provider_key=provider_key,
839
  messages=req.messages,
840
  model=req.model or Config.DEFAULT_MODEL,
 
842
  stream_flag=req.stream,
843
  timeout=Config.TIMEOUT,
844
  ):
845
+ collected.append(chunk.decode("utf-8", errors="ignore") if isinstance(chunk, (bytes, bytearray)) else str(chunk))
846
+ return JSONResponse({"text": "".join(collected)})
 
 
 
 
847
 
848
 
849
  @app.get("/model")
850
  async def model():
851
+ return {"models": [M1, M2, M3]}
 
852
 
853
 
854
  @app.get("/health")