multimodalart's picture
multimodalart HF Staff
Upload 75 files
a34bca4 verified
"""
"""
import traceback
import threading
import warnings
from pathlib import Path
from queue import Queue
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing_extensions import assert_type
from typing_extensions import Never
from uuid import uuid4
from fastapi import APIRouter
from fastapi import FastAPI
from fastapi import HTTPException
from uvicorn import Config
from uvicorn import Server
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=FutureWarning, message=r'.*torch\.distributed\.reduce_op.*')
from .._vendor.jurigged import codetools
from .._vendor.jurigged.codetools import CodeFile
from .._vendor.jurigged.codetools import CodeFileOperation
from .._vendor.jurigged.codetools import AddOperation
from .._vendor.jurigged.codetools import UpdateOperation
from .._vendor.jurigged.codetools import DeleteOperation
from .._vendor.jurigged.codetools import ExceptionOperation
from .._vendor.jurigged.codetools import Extent
from .._vendor.jurigged.codetools import LineDefinition
from .._vendor.jurigged.codetools import GroupDefinition
from .._vendor.jurigged.codetools import ModuleCode
from .._vendor.jurigged.codetools import ClassDefinition
from .._vendor.jurigged.codetools import FunctionDefinition
from .._vendor.jurigged.register import EventSource
from .._vendor.jurigged.register import Registry
from .._vendor.jurigged.utils import glob_filter
from .._vendor.sse_starlette import EventSourceResponse
from ..utils import create_thread
from .types import ApiCreateReloadRequest
from .types import ApiCreateReloadResponse
from .types import ApiCreateReloadResponseError
from .types import ApiCreateReloadResponseSuccess
from .types import ApiGetReloadRequest
from .types import ApiGetReloadEventSourceData
from .types import ApiGetStatusRequest
from .types import ApiGetStatusResponse
from .types import ApiFetchContentsRequest
from .types import ApiFetchContentsResponse
from .types import ApiFetchContentsResponseError
from .types import ApiFetchContentsResponseSuccess
from .types import ReloadRegion
from .types import ReloadOperationError
from .types import ReloadOperationException
from .types import ReloadOperationObject
from .types import ReloadOperationRun
from .types import ReloadOperationUI
class ReloadRun(NamedTuple):
    """Bookkeeping for a single reload created by ``ReloadServer``.

    ``activity`` is the event source that ``/get-reload`` subscribers listen
    on; ``thread`` is the worker thread executing the reload.
    """

    # Event source the reload emits its progress events onto.
    activity: EventSource
    # Background thread running ``ReloadServer.run_reload``.
    thread: threading.Thread
class ReloadServer:
    """HTTP server that hot-reloads project source files on request.

    Routes registered on ``self.router``:

    - ``GET /healthz``: liveness probe, returns ``{}``.
    - ``POST /get-status``: not implemented here (raises ``NotImplementedError``).
    - ``POST /fetch-contents``: return the text of a file under the working directory.
    - ``POST /create-reload``: overwrite a registered file and reload it in a thread.
    - ``POST /get-reload``: stream a reload's events as server-sent events.

    At most one reload runs at a time; ``active_reload_id`` holds its id while
    it is running and is reset to ``None`` when it finishes.
    """

    def __init__(
        self,
        prerun: Callable[[], Any],
        postrun: Callable[[], bool],
        stop_event: threading.Event,
    ):
        """
        Args:
            prerun: Called in the reload thread before the file is refreshed.
            postrun: Called after a successful refresh; its boolean result is
                reported to clients as ``updated``.
            stop_event: ``run`` blocks on this event and asks the HTTP server
                to exit once it is set.
        """
        self.prerun = prerun
        self.postrun = postrun
        self.stop_event = stop_event
        # Serializes create-reload requests so two of them cannot both pass
        # the "already reloading" check.
        self.create_lock = threading.Lock()
        # Every reload created so far, keyed by reload id, so clients can
        # subscribe to its events via /get-reload.
        self.reload_runs: dict[str, ReloadRun] = {}
        self.active_reload_id: str | None = None
        self.router = APIRouter()
        self.router.add_api_route('/healthz', lambda: {}, methods=['GET'])
        self.router.add_api_route('/get-status', self.get_status, methods=['POST'])
        self.router.add_api_route('/fetch-contents', self.fetch_contents, methods=['POST'])
        self.router.add_api_route('/create-reload', self.create_reload, methods=['POST'])
        self.router.add_api_route('/get-reload', self.get_reload, methods=['POST'])
        self.registry = Registry()
        # Presumably registers the `*.py` files matched by the glob
        # './*.py' for reloading; confirm against jurigged's glob_filter.
        self.registry.auto_register(filter=glob_filter('./*.py'))

    @staticmethod
    def _resolve_in_cwd(filepath):
        """Resolve ``filepath`` relative to the current working directory.

        Returns the fully resolved absolute path, or ``None`` when the path
        escapes the working directory (for example via ``..`` components or a
        symlink pointing outside it). Both the containment check and any later
        I/O must use the returned path so they act on the same file.
        """
        cwd = Path.cwd().resolve()
        resolved = (cwd / filepath).resolve()
        if not resolved.is_relative_to(cwd):
            return None
        return resolved

    def get_status(self, req: ApiGetStatusRequest) -> ApiGetStatusResponse:
        """
        POST /get-status
        """
        raise NotImplementedError # pragma: no cover

    def fetch_contents(self, req: ApiFetchContentsRequest) -> ApiFetchContentsResponse:
        """
        POST /fetch-contents

        Return the text of ``req.filepath``. Responds with a ``fileNotFound``
        error if the path escapes the working directory or is not a regular file.
        """
        filepath = self._resolve_in_cwd(req.filepath)
        if filepath is None or not filepath.is_file():
            res = ApiFetchContentsResponseError(status='fileNotFound')
            return ApiFetchContentsResponse(res=res)
        res = ApiFetchContentsResponseSuccess(status='ok', contents=filepath.read_text())
        return ApiFetchContentsResponse(res=res)

    def _create_reload(self, req: ApiCreateReloadRequest):
        """Write the new file contents and start a reload thread.

        Must be called with ``create_lock`` held. Returns an
        ``alreadyReloading`` error while another reload is active, and a
        ``fileNotFound`` error when the path escapes the working directory or
        is not tracked by the registry.
        """
        if self.active_reload_id is not None:
            return ApiCreateReloadResponseError(status='alreadyReloading')
        filepath = self._resolve_in_cwd(req.filepath)
        if filepath is None or (code_file := self.registry.get(str(filepath))) is None:
            return ApiCreateReloadResponseError(status='fileNotFound')
        filepath.write_text(req.contents)
        activity = EventSource(save_history=True) # TODO: Generic EventSource[ApiGetReloadEventSourceMessage]
        code_file.activity = activity
        thread = create_thread(self.run_reload, code_file)
        # Callers may supply their own id; otherwise a random one is generated.
        reload_id = str(uuid4()) if req.reloadId is None else req.reloadId
        self.reload_runs[reload_id] = ReloadRun(activity, thread)
        self.active_reload_id = reload_id
        thread.start()
        return ApiCreateReloadResponseSuccess(status='created', reloadId=reload_id)

    def create_reload(self, req: ApiCreateReloadRequest) -> ApiCreateReloadResponse:
        """
        POST /create-reload
        """
        with self.create_lock:
            res = self._create_reload(req)
        return ApiCreateReloadResponse(res=res)

    def get_reload(self, req: ApiGetReloadRequest):
        """
        POST /get-reload

        Stream the events of reload ``req.reloadId`` as server-sent events,
        one JSON document per event, until the reload emits its ``None``
        end-of-stream sentinel. Raises ``HTTPException(404)`` for an unknown id.
        """
        reload_id = req.reloadId
        if (run := self.reload_runs.get(reload_id, None)) is None:
            raise HTTPException(404, f"{reload_id=} not found")
        queue = Queue()
        # apply_history=True presumably replays events emitted before this
        # subscription, so late subscribers still see the whole reload.
        run.activity.register(queue.put, apply_history=True)

        def activity_stream(queue: Queue):
            # Already-serialized payloads pass through; raw jurigged operations
            # are converted, and those with no client representation are skipped.
            while (item := queue.get()) is not None:
                if isinstance(item, ApiGetReloadEventSourceData):
                    yield item
                elif isinstance(item, CodeFileOperation):
                    if (op := serialize_code_operation(item)) is not None:
                        yield ApiGetReloadEventSourceData(data=op)

        def event_source_stream():
            for data in activity_stream(queue):
                yield data.model_dump_json()

        return EventSourceResponse(event_source_stream())

    def run_reload(self, code_file: CodeFile):
        """Reload ``code_file`` and report progress on ``code_file.activity``.

        Runs in the worker thread started by ``_create_reload``. Any
        ``Exception`` from ``prerun``, the refresh, or ``postrun`` is emitted
        as an ``error`` event. A final ``ui`` event then reports whether
        ``postrun`` signalled an update.

        The ``None`` end-of-stream sentinel and the reset of
        ``active_reload_id`` happen in ``finally`` blocks. Even if something
        escapes the handler (e.g. ``SystemExit`` raised by reloaded code),
        subscribers are released and the server can accept new reloads.
        """
        try:
            updated = False
            try:
                self.prerun()
                code_file.refresh()
                updated = self.postrun()
            except Exception as exc:
                tb = format_traceback(exc)
                op = ReloadOperationError(kind='error', traceback=tb)
                data = ApiGetReloadEventSourceData(data=op)
                code_file.activity.emit(data)
            op = ReloadOperationUI(kind='ui', updated=updated)
            data = ApiGetReloadEventSourceData(data=op)
            code_file.activity.emit(data)
        finally:
            try:
                # `None` terminates the consumer loop in get_reload.
                code_file.activity.emit(None)
            finally:
                self.active_reload_id = None

    def run(self, port: int):
        """Serve the router with uvicorn on ``port`` until ``stop_event`` is set.

        The server runs in a background thread and listens on all interfaces
        (``0.0.0.0``). Once ``stop_event`` is set, this sets
        ``server.should_exit`` and returns without waiting for the server
        thread to finish.
        """
        app = FastAPI()
        app.include_router(self.router)
        server = Server(Config(app, host='0.0.0.0', port=port, log_level='warning'))
        server_thread = create_thread(server.run)
        server_thread.start()
        self.stop_event.wait()
        server.should_exit = True
def serialize_code_operation(cf_operation: CodeFileOperation):
    """Translate a jurigged code-file operation into a reload API payload.

    - ``ExceptionOperation`` becomes an ``exception`` payload with a traceback.
    - On a group definition (class/module/function), add becomes ``add``,
      an update becomes ``update`` only for functions, and any other
      operation becomes ``delete``.
    - On a line definition, anything but a delete becomes a ``run`` payload
      carrying the executed source lines.

    Returns ``None`` for operations the client is not told about: updates of
    non-function groups and deletions of line definitions.
    """
    definition = cf_operation.defn
    assert isinstance(definition, (LineDefinition, GroupDefinition))
    region = serialize_extent(definition.stashed)

    if isinstance(cf_operation, ExceptionOperation):
        error = cf_operation.exc
        if error is None: # pragma: no cover
            error = Exception("Unable to retrieve reload exception")
        return ReloadOperationException(
            kind='exception',
            region=region,
            traceback=format_traceback(error),
        )

    if isinstance(definition, GroupDefinition):
        object_type = get_object_type(definition)
        object_name = definition.dotpath()
        if isinstance(cf_operation, AddOperation):
            kind = 'add'
        elif not isinstance(cf_operation, UpdateOperation):
            kind = 'delete'
        elif isinstance(definition, FunctionDefinition):
            kind = 'update'
        else:
            # Updates to classes/modules are not reported individually.
            return None
        return ReloadOperationObject(
            kind=kind,
            region=region,
            objectType=object_type,
            objectName=object_name,
        )

    if isinstance(definition, LineDefinition):
        if isinstance(cf_operation, DeleteOperation):
            return None
        return ReloadOperationRun(kind='run', region=region, codeLines=definition.text)

    # Exhaustiveness check: every definition kind is handled above.
    assert_type(definition, Never) # pragma: no cover
def get_object_type(defn: GroupDefinition):
    """Return the API object-type label for a jurigged group definition.

    Classes are checked first, then modules, then functions; anything else
    is reported as ``'unknown'``.
    """
    for definition_cls, label in (
        (ClassDefinition, 'class'),
        (ModuleCode, 'module'),
        (FunctionDefinition, 'function'),
    ):
        if isinstance(defn, definition_cls):
            return label
    return 'unknown' # pragma: no cover
def serialize_extent(extent: Extent):
    """Convert a jurigged ``Extent`` into the API's ``ReloadRegion`` model.

    Line and column values are copied through unchanged.
    """
    region_fields = {
        'startLine': extent.lineno,
        'startCol': extent.col_offset,
        'endLine': extent.end_lineno,
        'endCol': extent.end_col_offset,
    }
    return ReloadRegion(**region_fields)
def format_traceback(exc: Exception):
    """Render ``exc`` with its traceback as a single string.

    Uses the exception's own ``__traceback__``. An exception that was never
    raised has none, so only its ``Type: message`` line is produced.
    """
    exc_type, exc_tb = type(exc), exc.__traceback__
    return "".join(traceback.format_exception(exc_type, exc, exc_tb))