"""
HTTP server that hot-reloads Python source files in the current working
directory and streams the progress of each reload to clients as
server-sent events.

Files matching the ``./*.py`` glob are registered with a vendored jurigged
``Registry``. A reload request overwrites a registered file on disk and then
refreshes it on a worker thread; the operations jurigged performs, plus any
error and a final ``ui`` status, are delivered through ``/get-reload``.
"""
import traceback
import threading
import warnings
from pathlib import Path
from queue import Queue
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing_extensions import assert_type
from typing_extensions import Never
from uuid import uuid4

from fastapi import APIRouter
from fastapi import FastAPI
from fastapi import HTTPException
from uvicorn import Config
from uvicorn import Server

with warnings.catch_warnings():
    # Silence a known FutureWarning about torch.distributed.reduce_op that can
    # surface while the vendored jurigged code tools are first imported.
    warnings.filterwarnings('ignore', category=FutureWarning, message=r'.*torch\.distributed\.reduce_op.*')
    from .._vendor.jurigged import codetools

from .._vendor.jurigged.codetools import CodeFile
from .._vendor.jurigged.codetools import CodeFileOperation
from .._vendor.jurigged.codetools import AddOperation
from .._vendor.jurigged.codetools import UpdateOperation
from .._vendor.jurigged.codetools import DeleteOperation
from .._vendor.jurigged.codetools import ExceptionOperation
from .._vendor.jurigged.codetools import Extent
from .._vendor.jurigged.codetools import LineDefinition
from .._vendor.jurigged.codetools import GroupDefinition
from .._vendor.jurigged.codetools import ModuleCode
from .._vendor.jurigged.codetools import ClassDefinition
from .._vendor.jurigged.codetools import FunctionDefinition
from .._vendor.jurigged.register import EventSource
from .._vendor.jurigged.register import Registry
from .._vendor.jurigged.utils import glob_filter
from .._vendor.sse_starlette import EventSourceResponse
from ..utils import create_thread
from .types import ApiCreateReloadRequest
from .types import ApiCreateReloadResponse
from .types import ApiCreateReloadResponseError
from .types import ApiCreateReloadResponseSuccess
from .types import ApiGetReloadRequest
from .types import ApiGetReloadEventSourceData
from .types import ApiGetStatusRequest
from .types import ApiGetStatusResponse
from .types import ApiFetchContentsRequest
from .types import ApiFetchContentsResponse
from .types import ApiFetchContentsResponseError
from .types import ApiFetchContentsResponseSuccess
from .types import ReloadRegion
from .types import ReloadOperationError
from .types import ReloadOperationException
from .types import ReloadOperationObject
from .types import ReloadOperationRun
from .types import ReloadOperationUI


class ReloadRun(NamedTuple):
    """One reload started by ``/create-reload``: its event source and worker thread."""

    activity: EventSource  # events emitted while reloading; history is kept for late subscribers
    thread: threading.Thread  # worker thread executing ReloadServer.run_reload


class ReloadServer:
    """
    Serves the hot-reload HTTP API.

    At most one reload runs at a time: ``active_reload_id`` holds the id of the
    running reload (or ``None``), is set under ``create_lock`` when a reload is
    created, and is cleared by the worker thread when the reload finishes.
    """

    def __init__(
        self,
        prerun: Callable[[], Any],
        postrun: Callable[[], bool],
        stop_event: threading.Event,
    ):
        """
        :param prerun: called on the worker thread before the file is refreshed.
        :param postrun: called after a successful refresh; its result is sent to
            clients as the ``updated`` flag of the final ``ui`` event.
        :param stop_event: ``run`` blocks until this event is set, then asks the
            HTTP server to exit.
        """
        self.prerun = prerun
        self.postrun = postrun
        self.stop_event = stop_event
        # Serializes create_reload so two requests can't both pass the
        # "already reloading" check.
        self.create_lock = threading.Lock()
        # Every reload ever created, keyed by reload id.
        # NOTE(review): entries are never removed, so memory grows with the
        # number of reloads over the server's lifetime.
        self.reload_runs: dict[str, ReloadRun] = {}
        # Id of the reload currently running, or None when idle.
        self.active_reload_id: str | None = None

        self.router = APIRouter()
        self.router.add_api_route('/healthz', lambda: {}, methods=['GET'])
        self.router.add_api_route('/get-status', self.get_status, methods=['POST'])
        self.router.add_api_route('/fetch-contents', self.fetch_contents, methods=['POST'])
        self.router.add_api_route('/create-reload', self.create_reload, methods=['POST'])
        self.router.add_api_route('/get-reload', self.get_reload, methods=['POST'])

        self.registry = Registry()
        self.registry.auto_register(filter=glob_filter('./*.py'))

    @staticmethod
    def _resolve_in_cwd(filepath) -> 'Path | None':
        """
        Resolve a client-supplied path against the current working directory.

        Returns the fully resolved path, or ``None`` when it lies outside the
        working directory (e.g. through ``..`` components or a symlink), so
        clients cannot read or overwrite arbitrary files.
        """
        # Resolve the root as well, so both sides of the comparison are
        # normalized the same way (symlinks, case, short names).
        root = Path.cwd().resolve()
        resolved = (root / filepath).resolve()
        return resolved if resolved.is_relative_to(root) else None

    def get_status(self, req: ApiGetStatusRequest) -> ApiGetStatusResponse:
        """
        POST /get-status
        """
        raise NotImplementedError  # pragma: no cover

    def fetch_contents(self, req: ApiFetchContentsRequest) -> ApiFetchContentsResponse:
        """
        POST /fetch-contents

        Return the text of ``req.filepath``, or a ``fileNotFound`` error when the
        path is outside the working directory or is not a regular file.
        """
        filepath = self._resolve_in_cwd(req.filepath)
        if filepath is None or not filepath.is_file():
            res = ApiFetchContentsResponseError(status='fileNotFound')
            return ApiFetchContentsResponse(res=res)
        # Python source is UTF-8 by default (PEP 3120); don't depend on the locale.
        res = ApiFetchContentsResponseSuccess(status='ok', contents=filepath.read_text(encoding='utf-8'))
        return ApiFetchContentsResponse(res=res)

    def _create_reload(self, req: ApiCreateReloadRequest):
        """
        Write ``req.contents`` to ``req.filepath`` and start reloading it.

        Must be called with ``create_lock`` held. Returns an
        ``ApiCreateReloadResponseSuccess`` carrying the reload id, or an
        ``ApiCreateReloadResponseError`` with status ``alreadyReloading`` (a
        reload is still running) or ``fileNotFound`` (path outside the working
        directory or not tracked by the registry).
        """
        if self.active_reload_id is not None:
            return ApiCreateReloadResponseError(status='alreadyReloading')

        filepath = self._resolve_in_cwd(req.filepath)
        # Only files already tracked by the jurigged registry can be reloaded.
        if filepath is None or (code_file := self.registry.get(str(filepath))) is None:
            return ApiCreateReloadResponseError(status='fileNotFound')

        # Write to the path that was actually validated above.
        filepath.write_text(req.contents, encoding='utf-8')

        # Keep history so /get-reload subscribers that connect late still
        # receive every event (they register with apply_history=True).
        activity = EventSource(save_history=True)  # TODO: Generic EventSource[ApiGetReloadEventSourceMessage]
        code_file.activity = activity
        thread = create_thread(self.run_reload, code_file)
        reload_id = str(uuid4()) if req.reloadId is None else req.reloadId
        self.reload_runs[reload_id] = ReloadRun(activity, thread)
        self.active_reload_id = reload_id
        thread.start()
        return ApiCreateReloadResponseSuccess(status='created', reloadId=reload_id)

    def create_reload(self, req: ApiCreateReloadRequest) -> ApiCreateReloadResponse:
        """
        POST /create-reload
        """
        with self.create_lock:
            res = self._create_reload(req)
        return ApiCreateReloadResponse(res=res)

    def get_reload(self, req: ApiGetReloadRequest):
        """
        POST /get-reload

        Stream the events of reload ``req.reloadId`` as server-sent events,
        starting with the events already emitted (history replay). The stream
        ends when the reload emits its terminal ``None``.

        :raises HTTPException: 404 if the reload id is unknown.
        """
        reload_id = req.reloadId
        if (run := self.reload_runs.get(reload_id, None)) is None:
            raise HTTPException(404, f"{reload_id=} not found")

        # NOTE(review): this listener is never unregistered, so each request's
        # queue stays referenced by the run's EventSource.
        queue = Queue()
        run.activity.register(queue.put, apply_history=True)

        def activity_stream(queue: Queue):
            # None is the end-of-stream sentinel emitted by run_reload.
            while (item := queue.get()) is not None:
                if isinstance(item, ApiGetReloadEventSourceData):
                    yield item
                elif isinstance(item, CodeFileOperation):
                    if (op := serialize_code_operation(item)) is not None:
                        yield ApiGetReloadEventSourceData(data=op)

        def event_source_stream():
            for data in activity_stream(queue):
                yield data.model_dump_json()

        return EventSourceResponse(event_source_stream())

    def run_reload(self, code_file: CodeFile):
        """
        Worker-thread body of a reload: ``prerun``, refresh ``code_file`` from
        disk, then ``postrun``.

        Emits on ``code_file.activity``:
          - an ``error`` event if any of those steps raises an ``Exception``;
          - a ``ui`` event carrying ``postrun``'s result (False on error);
          - ``None``, which ends every ``/get-reload`` stream.

        The terminal ``None`` and the reset of ``active_reload_id`` always
        happen, even when user code raises a ``BaseException`` such as
        ``SystemExit``. Otherwise subscribers would block forever and every
        later reload would be rejected as ``alreadyReloading``.
        """
        try:
            updated = False
            try:
                self.prerun()
                code_file.refresh()
                updated = self.postrun()
            except Exception as exc:
                tb = format_traceback(exc)
                op = ReloadOperationError(kind='error', traceback=tb)
                code_file.activity.emit(ApiGetReloadEventSourceData(data=op))
            op = ReloadOperationUI(kind='ui', updated=updated)
            code_file.activity.emit(ApiGetReloadEventSourceData(data=op))
        finally:
            try:
                code_file.activity.emit(None)
            finally:
                self.active_reload_id = None

    def run(self, port: int):
        """
        Serve the API on ``0.0.0.0:port`` from a background thread until
        ``stop_event`` is set, then ask the uvicorn server to exit.

        Returns right after requesting shutdown; it does not join the server
        thread.
        """
        app = FastAPI()
        app.include_router(self.router)
        server = Server(Config(app, host='0.0.0.0', port=port, log_level='warning'))
        server_thread = create_thread(server.run)
        server_thread.start()
        self.stop_event.wait()
        server.should_exit = True


def serialize_code_operation(cf_operation: CodeFileOperation):
    """
    Convert a jurigged code-file operation into the API's reload-operation model.

    Returns:
      - ``ReloadOperationException`` for an ``ExceptionOperation``;
      - ``ReloadOperationObject`` for a ``GroupDefinition`` (class, module or
        function): kind ``add``, ``update`` (functions only; updates of other
        groups return None), or ``delete`` (any other operation);
      - ``ReloadOperationRun`` carrying the line's text for a non-delete
        operation on a ``LineDefinition``;
      - ``None`` for operations that are not reported to clients.
    """
    # Type-narrowing invariant: jurigged operations target line or group definitions.
    assert isinstance(defn := cf_operation.defn, (LineDefinition, GroupDefinition))
    region = serialize_extent(defn.stashed)

    if isinstance(cf_operation, ExceptionOperation):
        if (exc := cf_operation.exc) is None:  # pragma: no cover
            exc = Exception("Unable to retrieve reload exception")
        tb = format_traceback(exc)
        return ReloadOperationException(kind='exception', region=region, traceback=tb)

    if isinstance(defn, GroupDefinition):
        otype = get_object_type(defn)
        oname = defn.dotpath()
        if isinstance(cf_operation, AddOperation):
            kind = 'add'
        elif isinstance(cf_operation, UpdateOperation):
            if not isinstance(defn, FunctionDefinition):
                return None
            kind = 'update'
        else:
            kind = 'delete'
        return ReloadOperationObject(kind=kind, region=region, objectType=otype, objectName=oname)

    if isinstance(defn, LineDefinition):
        if isinstance(cf_operation, DeleteOperation):
            return None
        return ReloadOperationRun(kind='run', region=region, codeLines=defn.text)

    assert_type(defn, Never)  # pragma: no cover


def get_object_type(defn: GroupDefinition):
    """Return 'class', 'module' or 'function' for a group definition ('unknown' otherwise)."""
    if isinstance(defn, ClassDefinition):
        return 'class'
    if isinstance(defn, ModuleCode):
        return 'module'
    if isinstance(defn, FunctionDefinition):
        return 'function'
    return 'unknown'  # pragma: no cover


def serialize_extent(extent: Extent):
    """Copy a jurigged ``Extent``'s start/end line and column into a ``ReloadRegion``."""
    return ReloadRegion(
        startLine=extent.lineno,
        startCol=extent.col_offset,
        endLine=extent.end_lineno,
        endCol=extent.end_col_offset,
    )


def format_traceback(exc: Exception):
    """Return the full formatted traceback of ``exc`` as a single string."""
    traces = traceback.format_exception(type(exc), exc, exc.__traceback__)
    return "".join(traces)