# samchun-gemini / checkpoints.py
# Hugging Face page header: uploaded by JHyeok5 with the commit message
# "Upload folder using huggingface_hub" (commit c481aa6, verified).
"""
Checkpoint system.

Lets a long-running job that was interrupted resume from the last saved point.
"""
import json
import logging
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, Any, Optional
# Module-level logger named after this module; used below to warn about
# checkpoint files that cannot be read or parsed.
logger = logging.getLogger(__name__)
class CheckpointManager:
    """Checkpoint manager: saves and restores the progress of a command.

    Every checkpoint is one JSON file named ``<checkpoint_id>.json`` inside
    ``checkpoint_dir``. All timestamps written by this class are real UTC
    times in ISO-8601 form with a trailing ``Z``.
    """

    def __init__(self, checkpoint_dir: Optional[str] = None):
        """Create the manager and make sure the checkpoint directory exists.

        Args:
            checkpoint_dir: Directory for checkpoint files. When ``None`` the
                ``CHECKPOINT_DIR`` environment variable is used, falling back
                to ``/tmp/checkpoints``.
        """
        if checkpoint_dir is None:
            checkpoint_dir = os.getenv("CHECKPOINT_DIR", "/tmp/checkpoints")
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS[.ffffff]Z``."""
        # The "Z" suffix means UTC, so the clock must be read in UTC; a naive
        # local time with "Z" appended would be off by the local UTC offset.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    def _path_for(self, checkpoint_id: str) -> Path:
        """Return the JSON file path that stores ``checkpoint_id``."""
        return self.checkpoint_dir / f"{checkpoint_id}.json"

    def create_checkpoint(
        self,
        command: str,
        user_input: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Create and persist a new checkpoint in the ``in_progress`` state.

        Args:
            command: Name of the command the checkpoint belongs to.
            user_input: Raw user input that started the command.
            metadata: Extra key/value pairs merged into the checkpoint's
                ``metadata`` section (``None`` is treated as empty).

        Returns:
            The id of the new checkpoint.
        """
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # Path separators in the command would point the file name into a
        # non-existent sub-directory; only the id is sanitized, the stored
        # "command" field keeps the original text.
        safe_command = command.replace("/", "_").replace("\\", "_")
        base_id = f"ckpt_{stamp}_{safe_command}"

        # The timestamp has one-second resolution, so add a numeric suffix
        # instead of silently overwriting a checkpoint from the same second.
        checkpoint_id = base_id
        attempt = 1
        while self._path_for(checkpoint_id).exists():
            attempt += 1
            checkpoint_id = f"{base_id}_{attempt}"

        now = self._utc_timestamp()
        checkpoint = {
            "id": checkpoint_id,
            "command": command,
            "status": "in_progress",
            "created_at": now,
            "updated_at": now,
            "metadata": {
                "initiated_by": "user",
                "user_input": user_input,
                **(metadata or {})
            },
            "progress": {
                "total_items": 0,
                "completed_items": 0,
                "failed_items": 0,
                "current_progress_percent": 0
            },
            "completed_items": [],
            "failed_items": [],
            "current_state": {},
            "next_steps": [],
            "error_log": [],
            "recovery_info": {},
            "workflow_specific": {}
        }
        self.save_checkpoint(checkpoint_id, checkpoint)
        return checkpoint_id

    def save_checkpoint(self, checkpoint_id: str, data: Dict[str, Any]):
        """Write ``data`` as the checkpoint ``checkpoint_id``.

        ``data["updated_at"]`` is refreshed in place before writing. The file
        is written to a temporary sibling and then moved over the target, so
        a failure while writing leaves the previous checkpoint intact.

        Raises:
            TypeError: If ``data`` contains values that are not JSON
                serializable.
            OSError: If the file cannot be written.
        """
        checkpoint_path = self._path_for(checkpoint_id)
        # ".json.tmp" is not matched by the "*.json" globs used for listing.
        tmp_path = self.checkpoint_dir / f"{checkpoint_id}.json.tmp"
        data["updated_at"] = self._utc_timestamp()
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is what the caller needs to see.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def load_checkpoint(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """Load a checkpoint.

        Returns:
            The parsed checkpoint, or ``None`` when no file exists for the id.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        checkpoint_path = self._path_for(checkpoint_id)
        if not checkpoint_path.exists():
            return None
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def update_progress(
        self,
        checkpoint_id: str,
        completed_item: Optional[str] = None,
        failed_item: Optional[str] = None,
        error: Optional[str] = None,
        current_state: Optional[Dict] = None,
        next_steps: Optional[list] = None,
        workflow_specific: Optional[Dict] = None,
        total_items: Optional[int] = None
    ):
        """Record progress on a checkpoint and save it.

        Args:
            checkpoint_id: Checkpoint to update.
            completed_item: Item to append to the completed list.
            failed_item: Item to append to the failed list (also logged in
                ``error_log`` together with ``error``).
            error: Error text stored with ``failed_item``.
            current_state: Replacement for ``current_state`` (ignored if empty).
            next_steps: Replacement for ``next_steps`` (ignored if empty).
            workflow_specific: Keys merged into ``workflow_specific``.
            total_items: New value for ``progress.total_items``; the percent
                figure is only computed while this total is positive.

        Returns:
            ``True`` when the checkpoint was updated, ``False`` when it does
            not exist.
        """
        checkpoint = self.load_checkpoint(checkpoint_id)
        if not checkpoint:
            return False

        progress = checkpoint["progress"]
        if total_items is not None:
            progress["total_items"] = total_items

        # Update completed_items.
        if completed_item:
            checkpoint["completed_items"].append(completed_item)
            progress["completed_items"] += 1

        # Update failed_items.
        if failed_item:
            timestamp = self._utc_timestamp()
            checkpoint["failed_items"].append({
                "item_id": failed_item,
                "error": error,
                "timestamp": timestamp
            })
            progress["failed_items"] += 1
            # Also add it to error_log.
            checkpoint["error_log"].append({
                "timestamp": timestamp,
                "error": error,
                "context": f"Processing {failed_item}"
            })

        # Compute the progress percentage.
        total = progress["total_items"]
        if total > 0:
            progress["current_progress_percent"] = int(
                (progress["completed_items"] / total) * 100
            )

        # Update state.
        if current_state:
            checkpoint["current_state"] = current_state
        if next_steps:
            checkpoint["next_steps"] = next_steps
        if workflow_specific:
            checkpoint["workflow_specific"].update(workflow_specific)

        self.save_checkpoint(checkpoint_id, checkpoint)
        return True

    def _set_status(
        self,
        checkpoint_id: str,
        status: str,
        can_resume: bool,
        **recovery_fields: Any
    ) -> None:
        """Set status and recovery info on an existing checkpoint and save it.

        Does nothing when the checkpoint does not exist.
        """
        checkpoint = self.load_checkpoint(checkpoint_id)
        if not checkpoint:
            return
        checkpoint["status"] = status
        recovery_info = checkpoint.setdefault("recovery_info", {})
        recovery_info.update(recovery_fields)
        recovery_info["can_resume"] = can_resume
        self.save_checkpoint(checkpoint_id, checkpoint)

    def mark_completed(self, checkpoint_id: str):
        """Mark the work as completed (no longer resumable)."""
        self._set_status(checkpoint_id, "completed", False)

    def mark_paused(
        self,
        checkpoint_id: str,
        reason: str,
        recovery_instructions: str
    ):
        """Mark the work as paused and resumable.

        Args:
            checkpoint_id: Checkpoint to update.
            reason: Why the work was paused.
            recovery_instructions: Text returned later by
                ``get_recovery_plan`` as ``instructions``.
        """
        self._set_status(
            checkpoint_id,
            "paused",
            True,
            pause_reason=reason,
            recovery_instructions=recovery_instructions,
        )

    def mark_failed(self, checkpoint_id: str, reason: str):
        """Mark the work as failed (not resumable) and store the reason."""
        self._set_status(checkpoint_id, "failed", False, failure_reason=reason)

    def get_recovery_plan(self, checkpoint_id: str) -> Optional[Dict[str, Any]]:
        """Look up the recovery plan of a checkpoint.

        Returns:
            A dict with ``checkpoint_id``, ``command``, ``resume_from``,
            ``next_steps``, ``context`` and ``instructions``, or ``None`` when
            the checkpoint is missing or not flagged ``can_resume``.
        """
        checkpoint = self.load_checkpoint(checkpoint_id)
        if not checkpoint:
            return None
        recovery_info = checkpoint.get("recovery_info", {})
        if not recovery_info.get("can_resume"):
            return None
        return {
            "checkpoint_id": checkpoint_id,
            "command": checkpoint["command"],
            "resume_from": recovery_info.get("resume_from", {}),
            "next_steps": checkpoint["next_steps"],
            "context": checkpoint["current_state"],
            "instructions": recovery_info.get(
                "recovery_instructions",
                "Resume from last checkpoint"
            )
        }

    def list_checkpoints(self, command: Optional[str] = None, status: Optional[str] = None) -> list:
        """List checkpoint summaries, newest first.

        Args:
            command: When given, keep only checkpoints of this command.
            status: When given, keep only checkpoints with this status.

        Returns:
            Summary dicts sorted by ``created_at`` in descending order. Files
            that cannot be read or lack required keys are skipped with a
            warning.
        """
        checkpoints = []
        for checkpoint_file in self.checkpoint_dir.glob("*.json"):
            try:
                with open(checkpoint_file, 'r', encoding='utf-8') as f:
                    checkpoint = json.load(f)
                # Filtering.
                if command is not None and checkpoint["command"] != command:
                    continue
                if status is not None and checkpoint["status"] != status:
                    continue
                checkpoints.append({
                    "id": checkpoint["id"],
                    "command": checkpoint["command"],
                    "status": checkpoint["status"],
                    "created_at": checkpoint["created_at"],
                    "updated_at": checkpoint["updated_at"],
                    "progress": checkpoint["progress"],
                    "can_resume": checkpoint.get("recovery_info", {}).get("can_resume", False)
                })
            except (OSError, ValueError, KeyError, TypeError) as e:
                # ValueError covers json.JSONDecodeError and decoding errors;
                # TypeError covers JSON documents that are not objects.
                logger.warning("Skipping invalid checkpoint file %s: %s", checkpoint_file, e)
                continue
        return sorted(checkpoints, key=lambda x: x["created_at"], reverse=True)

    def cleanup_old_checkpoints(self, days: int = 7):
        """Delete old finished checkpoints (default: 7 days).

        Only checkpoints whose status is ``completed`` or ``failed`` and whose
        ``created_at`` is older than ``days`` days are removed.

        Returns:
            The number of deleted checkpoint files.
        """
        # Use a timezone-aware cutoff date.
        cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
        deleted_count = 0
        for checkpoint_file in self.checkpoint_dir.glob("*.json"):
            try:
                with open(checkpoint_file, 'r', encoding='utf-8') as f:
                    checkpoint = json.load(f)
                if checkpoint["status"] not in ("completed", "failed"):
                    continue
                # Parse into a timezone-aware datetime.
                created = datetime.fromisoformat(checkpoint["created_at"].replace("Z", "+00:00"))
                if created.tzinfo is None:
                    # A timestamp without an offset is taken as UTC so it can
                    # be compared with the aware cutoff.
                    created = created.replace(tzinfo=timezone.utc)
                if created < cutoff_date:
                    checkpoint_file.unlink()
                    deleted_count += 1
            except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
                logger.warning("Error processing checkpoint file %s: %s", checkpoint_file, e)
                continue
        return deleted_count

    def check_and_resume_from_checkpoint(self, command: str) -> Optional[Dict[str, Any]]:
        """Check whether the command can be resumed from an earlier checkpoint.

        Returns:
            The recovery plan of the newest resumable checkpoint of
            ``command`` (status ``paused`` or ``in_progress``), or ``None``.
        """
        # NOTE(review): create_checkpoint never sets can_resume, so in this
        # class only paused checkpoints can match - confirm whether
        # interrupted in_progress runs are meant to be resumable.
        for ckpt in self.list_checkpoints(command=command):
            if not ckpt["can_resume"]:
                continue
            if ckpt["status"] not in ("paused", "in_progress"):
                continue
            recovery_plan = self.get_recovery_plan(ckpt["id"])
            if recovery_plan:
                return recovery_plan
        return None