PinkSky / server /process_manager.py
FreshPixels's picture
Rename process_manager.py to server/process_manager.py
90388cb verified
Raw
History Blame Contribute Delete
3.19 kB
"""Управление потоками и отмена генерации"""
import logging
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional
class ProcessManager:
    """Track worker threads and futures and request their cancellation.

    Cancellation is cooperative: nothing here forcibly stops a thread.
    cancel_all() sets a per-thread flag and briefly joins each thread; the
    thread's own code has to check is_cancelled() and return by itself.
    """

    def __init__(self):
        # Registered threads that have not been cancelled or pruned yet.
        self.active_threads: List[threading.Thread] = []
        # Cancellation ID (stored on the thread as ``_cancel_id``) -> flag.
        self.cancel_flags: Dict[str, bool] = {}
        # Guards active_threads, cancel_flags, futures and _cancelled_threads.
        # It is never held while joining threads or cancelling futures, so
        # workers can keep calling is_cancelled() during cancel_all().
        self.lock = threading.Lock()
        # Not used by the methods of this class; presumably shared with
        # callers — verify before removing.
        self.executor = ThreadPoolExecutor(max_workers=10)
        # Futures registered via register_future() that may still be pending.
        self.futures = []
        # Cancelled threads that were still running when cancel_all() stopped
        # waiting; their flags stay True until clear() sees them finish.
        self._cancelled_threads: List[threading.Thread] = []
        self.logger = logging.getLogger(__name__)

    def register_thread(self, thread: threading.Thread) -> None:
        """Start tracking *thread* and give it a fresh cancellation ID.

        The ID is stored on the thread object as ``thread._cancel_id`` so code
        running inside that thread can call is_cancelled() with no argument.
        """
        # Generate a unique ID for the thread.
        thread_id = str(uuid.uuid4())
        with self.lock:
            self.active_threads.append(thread)
            self.cancel_flags[thread_id] = False
            # Store the ID on the thread so it can be looked up later.
            thread._cancel_id = thread_id

    def register_future(self, future) -> None:
        """Track *future* so that cancel_all() can cancel it.

        Futures that have already completed are dropped at the same time, so
        the list does not grow without bound in a long-running process.
        """
        with self.lock:
            self.futures[:] = [f for f in self.futures if not f.done()]
            self.futures.append(future)

    def cancel_all(self) -> str:
        """Request cancellation of every tracked future and thread.

        Pending futures are cancelled (``Future.cancel`` cannot interrupt a
        future that is already running). Every registered thread's flag is
        set to True and each thread is given up to 0.5 s to exit. The lock is
        released before waiting, so threads can observe the flag through
        is_cancelled(). Flags of threads that have exited are dropped; flags
        of threads still running stay True. Afterwards the ``interpreter``
        package is asked to cancel (best effort), and the shared ``STATE`` gets
        ``cancel_flag = True`` and ``current_mode = "paused"``.

        Returns:
            A human-readable status banner.
        """
        with self.lock:
            futures = list(self.futures)
            self.futures.clear()
            threads = list(self.active_threads)
            self.active_threads.clear()
            for thread in threads:
                thread_id = getattr(thread, "_cancel_id", None)
                if thread_id is not None:
                    self.cancel_flags[thread_id] = True

        # Future.cancel() runs done-callbacks synchronously in this thread;
        # doing it outside the (non-reentrant) lock keeps a callback that
        # calls back into this manager from deadlocking.
        for future in futures:
            if not future.done():
                future.cancel()

        self._join_cancelled(threads)
        self._cancel_interpreter()

        from .state import STATE
        STATE.cancel_flag = True
        STATE.current_mode = "paused"
        return (
            "╔════════════════════╗\n"
            "║ ✅ ᴀʟʟ ᴘʀᴏᴄᴇssᴇs sᴛᴏᴘᴘᴇᴅ ║\n"
            "║ Memory flushed ║\n"
            "║ Threads terminated ║\n"
            "╚════════════════════╝"
        )

    def _join_cancelled(self, threads: List[threading.Thread]) -> None:
        """Wait briefly for cancelled *threads*, then settle their flags."""
        current = threading.current_thread()
        for thread in threads:
            # Joining the calling thread itself would raise RuntimeError.
            if thread is current or not thread.is_alive():
                continue
            try:
                thread.join(timeout=0.5)
            except RuntimeError as e:
                self.logger.warning("Failed to join thread %s: %s", thread.name, e)
        with self.lock:
            for thread in threads:
                if thread.is_alive():
                    # Still running: keep its flag True so it can still
                    # notice the cancellation; clear() drops it later.
                    self._cancelled_threads.append(thread)
                else:
                    self.cancel_flags.pop(getattr(thread, "_cancel_id", None), None)

    def _cancel_interpreter(self) -> None:
        """Best effort: ask the ``interpreter`` package, if importable, to stop."""
        try:
            from interpreter import interpreter
            if hasattr(interpreter, "cancel"):
                interpreter.cancel()
        except Exception as e:  # deliberately broad: optional dependency
            self.logger.error("Failed to cancel interpreter: %s", e)

    def is_cancelled(self, thread_id: Optional[str] = None) -> bool:
        """Return True if cancellation was requested for *thread_id*.

        With no argument, the ID that register_thread() attached to the
        calling thread is used; unknown or unregistered IDs report False.
        """
        if thread_id is None:
            thread_id = getattr(threading.current_thread(), "_cancel_id", None)
        with self.lock:
            return self.cancel_flags.get(thread_id, False)

    def clear(self) -> None:
        """Forget threads that are no longer alive and drop their flags."""
        with self.lock:
            for bucket in (self.active_threads, self._cancelled_threads):
                alive = []
                for thread in bucket:
                    if thread.is_alive():
                        alive.append(thread)
                    else:
                        self.cancel_flags.pop(getattr(thread, "_cancel_id", None), None)
                bucket[:] = alive

    def get_active_count(self) -> int:
        """Return how many tracked, not-yet-cancelled threads are alive."""
        with self.lock:
            return sum(1 for thread in self.active_threads if thread.is_alive())