# doc-translator / babeldoc / progress_monitor.py
# Taha Mahmood: Initial upload (754d92a)
import asyncio
import logging
import threading
import time
from asyncio import CancelledError
from collections.abc import Callable
from typing import Optional
# Module-level logger used for the monitor's debug/info/warning messages.
logger = logging.getLogger(__name__)
class ProgressMonitor:
def __init__(
self,
stages: list[tuple[str, float]],
progress_change_callback: Callable | None = None,
finish_callback: Callable | None = None,
report_interval: float = 0.1,
finish_event: asyncio.Event | None = None,
cancel_event: threading.Event | None = None,
loop: asyncio.AbstractEventLoop | None = None,
parent_monitor: Optional["ProgressMonitor"] = None,
part_index: int | None = 0,
total_parts: int | None = 1,
):
self.lock = threading.Lock()
self.parent_monitor = parent_monitor
self.part_index = part_index
self.total_parts = total_parts
self.raw_stages = stages
self.part_results = {}
# Convert stages list to dict with name and weight
self.stage = {}
total_weight = sum(weight for _, weight in stages)
for name, weight in stages:
normalized_weight = weight / total_weight
self.stage[name] = TranslationStage(
name,
0,
self,
normalized_weight,
self.lock,
)
self.progress_change_callback = progress_change_callback
self.finish_callback = finish_callback
self.report_interval = report_interval
logger.debug(f"report_interval: {self.report_interval}")
self.last_report_time = 0
self.finish_stage_count = 0
self.finish_event = finish_event
self.cancel_event = cancel_event
self.loop = loop
self.disable = False
if finish_event and not loop:
raise ValueError("finish_event requires a loop")
if self.progress_change_callback:
self.progress_change_callback(
type="stage_summary",
stages=[
{
"name": name,
"percent": self.stage[name].weight,
}
for name, _ in stages
],
part_index=self.part_index,
total_parts=self.total_parts,
)
def create_part_monitor(
self, part_index: int, total_parts: int
) -> "ProgressMonitor":
"""Create a new progress monitor for a document part"""
return ProgressMonitor(
stages=self.raw_stages,
progress_change_callback=self._handle_part_progress,
finish_callback=self._handle_part_finish,
report_interval=self.report_interval,
cancel_event=self.cancel_event,
loop=self.loop,
parent_monitor=self,
part_index=part_index,
total_parts=total_parts,
)
def _handle_part_progress(self, **kwargs):
"""Handle progress updates from part monitors"""
if self.progress_change_callback and not self.disable:
# Add part information to progress update
kwargs["part_index"] = kwargs.get("part_index")
kwargs["total_parts"] = kwargs.get("total_parts")
self.progress_change_callback(**kwargs)
def _handle_part_finish(self, **kwargs):
"""Handle completion of a part translation"""
if kwargs["type"] == "error":
logger.info(f"progress_monitor handle_part_finish: {kwargs['error']}")
self.finish_callback(type="error", error=kwargs["error"])
return
if "translate_result" in kwargs:
part_index = kwargs.get("part_index")
if part_index is not None:
self.part_results[part_index] = kwargs["translate_result"]
# if self.finish_callback and not self.disable:
# self.finish_callback(**kwargs)
def stage_start(self, stage_name: str, total: int):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return DummyTranslationStage(stage_name, total, self, 0)
stage = self.stage[stage_name]
stage.run_time += 1
stage.name = stage_name
stage.display_name = f"{stage_name}" if stage.run_time > 1 else stage_name
stage.current = 0
stage.total = total
if self.progress_change_callback:
self.progress_change_callback(
type="progress_start",
stage=stage.display_name,
stage_progress=0.0,
stage_current=0,
stage_total=total,
overall_progress=self.calculate_current_progress(),
part_index=self.part_index + 1,
total_parts=self.total_parts,
)
self.last_report_time = 0.0
return stage
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
logger.debug("ProgressMonitor __exit__")
def on_finish(self):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return
if self.cancel_event:
self.cancel_event.set()
if self.finish_event and self.loop:
self.loop.call_soon_threadsafe(self.finish_event.set)
if self.cancel_event and self.cancel_event.is_set():
self.finish_callback(type="error", error=CancelledError)
def stage_done(self, stage):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return
self.last_report_time = 0.0
self.finish_stage_count += 1
if (
stage.current != stage.total
and self.cancel_event is not None
and not self.cancel_event.is_set()
):
logger.warning(
f"Stage {stage.name} completed with {stage.current}/{stage.total} items",
)
return
if self.progress_change_callback:
self.progress_change_callback(
type="progress_end",
stage=stage.display_name,
stage_progress=100.0,
stage_current=stage.total,
stage_total=stage.total,
overall_progress=self.calculate_current_progress(),
part_index=self.part_index + 1,
total_parts=self.total_parts,
)
def calculate_current_progress(self, stage=None):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return 100
part_weight = 1 / self.total_parts
if self.parent_monitor:
part_offset = self.part_index * part_weight
else:
part_offset = len(self.part_results) * part_weight
part_offset *= 100
progress = self._calculate_current_progress(stage) * part_weight + part_offset
return progress
def _calculate_current_progress(self, stage=None):
"""Calculate overall progress including part progress"""
# Count completed stages
completed_stages = sum(
1 for s in self.stage.values() if s.run_time > 0 and s.current == s.total
)
# If all stages are complete, return exactly 100
if completed_stages == len(self.stage):
return 100
# Calculate progress based on weights
progress = sum(
s.weight * 100
for s in self.stage.values()
if s.run_time > 0 and s.current == s.total
)
if stage is not None and 0 < stage.total != stage.current:
progress += stage.weight * stage.current * 100 / stage.total
# If this is a part monitor (has parent_monitor), return the progress as is
if hasattr(self, "parent_monitor") and self.parent_monitor:
return progress
# Otherwise return the standard progress
return progress
def stage_update(self, stage, n: int):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return
report_time_delta = time.time() - self.last_report_time
if report_time_delta < self.report_interval and stage.total > 3:
return
if self.progress_change_callback:
if stage.total != 0:
stage_progress = stage.current * 100 / stage.total
else:
stage_progress = 100
self.progress_change_callback(
type="progress_update",
stage=stage.display_name,
stage_progress=stage_progress,
stage_current=stage.current,
stage_total=stage.total,
overall_progress=self.calculate_current_progress(stage),
part_index=self.part_index + 1,
total_parts=self.total_parts,
)
self.last_report_time = time.time()
def translate_done(self, translate_result):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return
if self.finish_callback:
self.finish_callback(type="finish", translate_result=translate_result)
def translate_error(self, error):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return
if self.finish_callback:
logger.info(f"progress_monitor handle translate_error: {error}")
self.finish_callback(type="error", error=error)
def raise_if_cancelled(self):
if self.cancel_event and self.cancel_event.is_set():
raise asyncio.CancelledError
def cancel(self):
if self.disable or self.parent_monitor and self.parent_monitor.disable:
return
if self.cancel_event:
logger.info("Translation canceled")
self.cancel_event.set()
class TranslationStage:
    """Mutable progress counter for one stage of a ProgressMonitor.

    Use it as a context manager. ``advance()`` increments ``current`` and
    reports through the owning monitor. Leaving the ``with`` block tops
    ``current`` up to ``total`` and marks the stage done.
    """

    def __init__(
        self,
        name: str,
        total: int,
        pm: "ProgressMonitor",
        weight: float,
        lock: threading.Lock,
    ):
        self.name = name
        self.display_name = name
        self.current = 0
        self.total = total
        self.pm = pm
        # Number of times this stage has been started (incremented by the monitor).
        self.run_time = 0
        # This stage's share of overall progress.
        self.weight = weight
        # Lock shared with the monitor; serialises counter updates and reports.
        self.lock = lock

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the stage: report any remaining items, then mark it done.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        with self.lock:
            remaining = self.total - self.current
            if remaining > 0:
                logger.info(
                    f"Stage {self.name} completed with {self.current}/{self.total} items"
                )
            # Set the counter before reporting, as advance() does, so the final
            # update carries total/total rather than the stale count.
            self.current = self.total
            self.pm.stage_update(self, remaining)
            self.pm.stage_done(self)

    def advance(self, n: int = 1):
        """Add ``n`` completed items and report progress through the monitor."""
        with self.lock:
            self.current += n
            self.pm.stage_update(self, n)
class DummyTranslationStage:
    """No-op stand-in for TranslationStage.

    ProgressMonitor.stage_start returns this when reporting is disabled. It has
    the same attributes and the same context-manager/advance interface, but it
    never changes its counters or notifies the monitor.
    """

    def __init__(self, name: str, total: int, pm: ProgressMonitor, weight: float):
        # `weight` is accepted for signature parity with TranslationStage; it is unused.
        self.name = self.display_name = name
        self.total = total
        self.current = 0
        self.pm = pm

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to finalise. Returning None lets exceptions propagate.
        return None

    def advance(self, n: int = 1):
        # Deliberately ignores `n`: the dummy stage never tracks progress.
        return None