# FictionAgent / utils/cache_manager.py
# Uploaded by gdwind via huggingface_hub (commit a226682, verified).
import os
import json
import hashlib
import pickle
from typing import Any, Optional, Dict # 添加 Dict
from pathlib import Path
from config import Config
class CacheManager:
    """File-backed cache storing one value per file under ``cache_dir``.

    Values are stored either as pickle files (``<key>.pkl``; see ``get`` /
    ``set``) or as JSON files (``<key>.json``; see ``save_json`` /
    ``load_json``). ``get``, ``set`` and ``exists`` do nothing when
    ``Config.ENABLE_CACHE`` is false; ``delete``, ``clear`` and the JSON
    helpers do not check that flag.

    NOTE(review): keys are used verbatim in file names, so a key containing a
    path separator or ``..`` would escape ``cache_dir`` — presumably callers
    pass hex digests from ``_get_cache_key``; confirm.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """Create the manager and ensure the cache directory exists.

        Args:
            cache_dir: Directory holding cache files. Any falsy value
                (``None``, ``""``) falls back to ``Config.CACHE_DIR``.
        """
        self.cache_dir = cache_dir or Config.CACHE_DIR
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, *args, **kwargs) -> str:
        """Build a deterministic key (MD5 hex digest) from call arguments.

        Keyword arguments are sorted by name, so their order does not affect
        the key. The key is derived from ``str()`` of the arguments, so
        objects whose ``str`` embeds a memory address give unstable keys.
        """
        # MD5 serves only as a compact fingerprint here, not for security.
        content = str(args) + str(sorted(kwargs.items()))
        return hashlib.md5(content.encode()).hexdigest()

    def _get_cache_path(self, key: str) -> str:
        """Return the path of the pickle file backing ``key``."""
        return os.path.join(self.cache_dir, f"{key}.pkl")

    def _atomic_write(self, target: str, mode: str, write,
                      encoding: Optional[str] = None) -> None:
        """Write ``target`` so it is either fully replaced or left untouched.

        ``write(f)`` receives a file object opened (with ``mode`` and
        ``encoding``) on a temporary sibling of ``target``. Only after it
        returns is the temporary file moved over ``target`` with
        ``os.replace``. On any failure the temporary file is removed and the
        exception is re-raised, so a previously stored entry survives.
        """
        # Suffix ".tmp" keeps half-written files out of the *.pkl scans.
        tmp_path = f"{target}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, mode, encoding=encoding) as f:
                write(f)
            os.replace(tmp_path, target)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _iter_cache_file_sizes(self):
        """Yield the size in bytes of each ``*.pkl`` file in ``cache_dir``.

        A file that disappears between the directory listing and the size
        lookup (e.g. removed by a concurrent ``clear``) is skipped rather
        than raising.
        """
        for name in os.listdir(self.cache_dir):
            if not name.endswith('.pkl'):
                continue
            try:
                size = os.path.getsize(os.path.join(self.cache_dir, name))
            except FileNotFoundError:
                continue
            yield size

    def get(self, key: str) -> Optional[Any]:
        """Return the pickled value stored under ``key``, or ``None``.

        ``None`` is returned when caching is disabled, when no entry exists,
        or when the entry cannot be read (the error is printed). A stored
        value of ``None`` is therefore indistinguishable from a miss.
        """
        if not Config.ENABLE_CACHE:
            return None
        cache_file = self._get_cache_path(key)
        try:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file; this is only safe while cache_dir is not writable by
            # untrusted parties.
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return None
        except Exception as e:
            print(f"读取缓存失败 ({key}): {e}")
            return None

    def set(self, key: str, value: Any) -> bool:
        """Pickle ``value`` under ``key``.

        The write is atomic: if pickling or writing fails, any previously
        stored entry for ``key`` is kept intact.

        Returns:
            True on success; False when caching is disabled or the write
            failed (the error is printed).
        """
        if not Config.ENABLE_CACHE:
            return False
        cache_file = self._get_cache_path(key)
        try:
            self._atomic_write(cache_file, 'wb', lambda f: pickle.dump(value, f))
            return True
        except Exception as e:
            print(f"缓存保存失败 ({key}): {e}")
            return False

    def exists(self, key: str) -> bool:
        """Return True if a pickle entry exists for ``key`` (False if caching is disabled)."""
        if not Config.ENABLE_CACHE:
            return False
        return os.path.exists(self._get_cache_path(key))

    def delete(self, key: str) -> bool:
        """Remove the pickle entry for ``key``.

        Returns:
            True if a file was removed; False if there was no entry or the
            removal failed (the error is printed).
        """
        cache_file = self._get_cache_path(key)
        try:
            os.remove(cache_file)
            return True
        except FileNotFoundError:
            return False
        except Exception as e:
            print(f"删除缓存失败 ({key}): {e}")
            return False

    def clear(self, pattern: Optional[str] = None):
        """Delete ``*.pkl`` files from the cache directory and print the count.

        Args:
            pattern: If given, only files whose name (including the ``.pkl``
                extension) contains this substring are deleted.

        JSON files written by ``save_json`` are not removed.
        """
        if not os.path.exists(self.cache_dir):
            return
        count = 0
        for file in os.listdir(self.cache_dir):
            if not file.endswith('.pkl'):
                continue
            if pattern is not None and pattern not in file:
                continue
            try:
                os.remove(os.path.join(self.cache_dir, file))
                count += 1
            except Exception as e:
                print(f"删除缓存文件失败 ({file}): {e}")
        print(f"已清除 {count} 个缓存文件")

    def get_cache_size(self) -> int:
        """Return the total size in bytes of all ``*.pkl`` cache files."""
        if not os.path.exists(self.cache_dir):
            return 0
        return sum(self._iter_cache_file_sizes())

    def get_cache_info(self) -> Dict[str, Any]:
        """Summarize the ``*.pkl`` cache files.

        Returns:
            A dict with ``count`` (number of files), ``size`` (total bytes)
            and ``size_mb`` (total MiB, rounded to two decimals).
        """
        if not os.path.exists(self.cache_dir):
            return {
                'count': 0,
                'size': 0,
                'size_mb': 0
            }
        sizes = list(self._iter_cache_file_sizes())
        total_size = sum(sizes)
        return {
            'count': len(sizes),
            'size': total_size,
            'size_mb': round(total_size / (1024 * 1024), 2)
        }

    def save_json(self, key: str, data: dict) -> bool:
        """Store ``data`` as human-readable UTF-8 JSON in ``<key>.json``.

        The write is atomic: if serialization fails part-way, any existing
        ``<key>.json`` is kept intact. This method does not check
        ``Config.ENABLE_CACHE``.

        Returns:
            True on success; False on failure (the error is printed).
        """
        cache_file = os.path.join(self.cache_dir, f"{key}.json")
        try:
            self._atomic_write(
                cache_file,
                'w',
                lambda f: json.dump(data, f, ensure_ascii=False, indent=2),
                encoding='utf-8',
            )
            return True
        except Exception as e:
            print(f"保存JSON缓存失败 ({key}): {e}")
            return False

    def load_json(self, key: str) -> Optional[dict]:
        """Load ``<key>.json``.

        Returns:
            The parsed content, or None if the file is missing or cannot be
            read or parsed (in the latter cases the error is printed).
        """
        cache_file = os.path.join(self.cache_dir, f"{key}.json")
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return None
        except Exception as e:
            print(f"读取JSON缓存失败 ({key}): {e}")
            return None