LeTue09's picture
initial clean commit
1faccd4
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import uvicorn
from fastapi import FastAPI
logger = logging.getLogger(__file__)
def get_max_position_embeddings(hf_config) -> int:
max_len = getattr(hf_config, "max_position_embeddings", None)
if max_len is None:
text_config = getattr(hf_config, "text_config", None)
if text_config is not None:
max_len = getattr(text_config, "max_position_embeddings", None)
if max_len is None:
raise ValueError("max_position_embeddings not found in HFModelConfig!")
return int(max_len)
class _UvicornServerAutoPort(uvicorn.Server):
"""Uvicorn Server that reports the system-assigned port when port=0."""
def __init__(self, config: uvicorn.Config) -> None:
super().__init__(config)
self.actual_port: int | None = None
self._startup_done: asyncio.Event = asyncio.Event()
async def startup(self, sockets=None) -> None:
try:
await super().startup(sockets=sockets)
if self.servers and self.config.port == 0:
sock = self.servers[0].sockets[0]
self.actual_port = sock.getsockname()[1]
else:
self.actual_port = self.config.port
finally:
self._startup_done.set()
async def get_port(self) -> int | None:
await self._startup_done.wait()
return self.actual_port
async def run_uvicorn(app: FastAPI, server_args, server_address) -> tuple[int, asyncio.Task]:
app.server_args = server_args
config = uvicorn.Config(app, host=server_address, port=0, log_level="warning")
server = _UvicornServerAutoPort(config)
server_task = asyncio.create_task(server.serve())
server_port = await server.get_port()
if server_port is None:
# server.startup() failed. await the task to re-raise exception from server.serve()
await server_task
# Fails on unexpected situation.
raise RuntimeError("Unexpected: HTTP server started without reporting listened port")
logger.info(f"HTTP server started on port {server_port}")
return server_port, server_task
async def ensure_async_iterator(iterable):
"""Convert an iterable to an async iterator."""
if hasattr(iterable, "__aiter__"):
async for item in iterable:
yield item
else:
for item in iterable:
yield item