Shield Agents
🛡️ Initial release - Shield Agents v1.0.0
de31cf7
Raw
History Blame Contribute Delete
7.29 kB
"""
Caching and Incremental Scanning for Shield Agents.
Cache previous scan results and only re-scan changed files.
This is the difference between a demo and a production tool.
"""
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from .utils.crypto import hash_file
# Module-level logger, registered under the "shield_agents.cache" name.
# ScanCache uses it to report cache load/save outcomes and expiry cleanups.
logger = logging.getLogger("shield_agents.cache")
class ScanCache:
    """Manage scan-result caching for incremental scanning.

    The cache lives in ``<cache_dir>/scan_cache.json`` as a JSON object with
    two sections:

    - ``"files"``: maps a file path to its entry, which holds the file hash
      (to detect changes), the findings of the last scan, a ``last_scanned``
      timestamp (for cache invalidation) and a ``finding_count``.
    - ``"metadata"``: global metadata such as ``last_saved``.

    The file is read lazily on first access and is only written back by an
    explicit call to :meth:`save`, and only if something changed.
    """

    def __init__(self, cache_dir: str = ".shield-cache", max_age_days: int = 30):
        """Initialize the scan cache.

        Nothing is read from disk here; loading happens on first access.

        Args:
            cache_dir: Directory for cache storage.
            max_age_days: Maximum age of cache entries in days. A falsy value
                (``0`` / ``None``) disables expiry in :meth:`cleanup_expired`.
        """
        self.cache_dir = Path(cache_dir)
        self.max_age_days = max_age_days
        self.cache_file = self.cache_dir / "scan_cache.json"
        self._cache: Dict[str, Any] = {}
        self._loaded = False
        # True when the in-memory cache differs from what is on disk.
        self.changed = False

    @staticmethod
    def _empty_cache() -> Dict[str, Any]:
        """Return a fresh, well-formed, empty cache structure."""
        return {"files": {}, "metadata": {}}

    @staticmethod
    def _coerce_cache(raw: Any) -> Dict[str, Any]:
        """Reshape decoded JSON into a well-formed cache dictionary.

        Anything that is not a JSON object becomes an empty cache. Missing or
        malformed ``"files"`` / ``"metadata"`` sections are replaced by empty
        dicts, and file entries that are not objects are dropped, so the
        accessors can rely on the structure instead of crashing on a
        hand-edited or foreign cache file.
        """
        if not isinstance(raw, dict):
            return ScanCache._empty_cache()
        files = raw.get("files")
        metadata = raw.get("metadata")
        if isinstance(files, dict):
            raw["files"] = {
                path: entry for path, entry in files.items() if isinstance(entry, dict)
            }
        else:
            raw["files"] = {}
        raw["metadata"] = metadata if isinstance(metadata, dict) else {}
        return raw

    def _ensure_loaded(self) -> None:
        """Load cache from disk if not already loaded."""
        if not self._loaded:
            self._load()
            self._loaded = True

    def _load(self) -> None:
        """Load cache from disk, falling back to an empty cache.

        A missing cache file silently yields an empty cache. An unreadable,
        undecodable or structurally invalid file is logged as a warning and
        also yields an empty cache; loading never raises for those cases.
        """
        raw: Any = None
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                raw = json.load(f)
        except FileNotFoundError:
            # No cache has been written yet: start empty without any noise.
            pass
        except (ValueError, OSError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Failed to load cache: %s", e)
        else:
            if not isinstance(raw, dict):
                logger.warning(
                    "Failed to load cache: %s", "top-level JSON value is not an object"
                )
                raw = None

        self._cache = self._coerce_cache(raw)
        if raw is not None:
            logger.debug("Loaded cache with %d entries", len(self._cache["files"]))

    def save(self) -> None:
        """Save cache to disk if it has unsaved changes.

        The data is written to a temporary file in the cache directory and
        then moved over the real cache file with ``os.replace``, so an
        interrupted or failed write never leaves a truncated cache behind.
        I/O failures (including failure to create the cache directory) are
        logged as warnings and leave ``changed`` set so a later call retries.
        """
        if not self.changed:
            return

        # Guarantees the cache has its sections even if ``changed`` was set
        # before anything triggered the lazy load.
        self._ensure_loaded()
        self._cache.setdefault("metadata", {})["last_saved"] = time.time()

        # The PID suffix keeps two processes from sharing one temp file.
        tmp_file = self.cache_file.with_name(
            f"{self.cache_file.name}.{os.getpid()}.tmp"
        )
        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self._cache, f, indent=2)
            os.replace(tmp_file, self.cache_file)
        except OSError as e:
            logger.warning("Failed to save cache: %s", e)
        else:
            self.changed = False
            logger.debug("Cache saved successfully")
        finally:
            # After a successful os.replace the temp file is already gone;
            # after a failure it may linger, so remove it best-effort.
            try:
                tmp_file.unlink()
            except OSError:
                pass

    def get_file_hash(self, file_path: str) -> Optional[str]:
        """Get the cached hash for a file.

        Args:
            file_path: Path to the file.

        Returns:
            Cached hash, or None if not cached.
        """
        self._ensure_loaded()
        return self._cache.get("files", {}).get(file_path, {}).get("hash")

    def compute_file_hash(self, file_path: str) -> Optional[str]:
        """Compute current hash for a file by delegating to ``hash_file``.

        Args:
            file_path: Path to the file.

        Returns:
            Whatever ``hash_file`` returns for the path. Callers in this class
            treat ``None`` as "file cannot be read".
        """
        return hash_file(file_path)

    def is_file_changed(self, file_path: str) -> bool:
        """Check if a file has changed since last scan.

        Args:
            file_path: Path to the file.

        Returns:
            True if the file is new or modified, False if it is unchanged or
            its current hash cannot be computed.
        """
        self._ensure_loaded()
        current_hash = self.compute_file_hash(file_path)
        cached_hash = self.get_file_hash(file_path)

        if current_hash is None:
            # Can't read the file, treat as unchanged.
            # NOTE(review): this means unreadable or deleted files are listed
            # as "unchanged" and keep their cached findings -- confirm this
            # is the intended policy for callers.
            return False
        if cached_hash is None:
            # Not cached, treat as new.
            return True
        return current_hash != cached_hash

    def get_changed_files(self, file_paths: List[str]) -> List[str]:
        """Get list of files that have changed since last scan.

        Args:
            file_paths: List of file paths to check.

        Returns:
            List of changed file paths, in input order.
        """
        return [fp for fp in file_paths if self.is_file_changed(fp)]

    def get_unchanged_files(self, file_paths: List[str]) -> List[str]:
        """Get list of files that haven't changed since last scan.

        Args:
            file_paths: List of file paths to check.

        Returns:
            List of unchanged file paths, in input order.
        """
        return [fp for fp in file_paths if not self.is_file_changed(fp)]

    def get_cached_findings(self, file_path: str) -> List[Dict[str, Any]]:
        """Get cached scan findings for a file.

        Args:
            file_path: Path to the file.

        Returns:
            The findings list stored for the file (the stored object itself,
            not a copy), or an empty list if the file is not cached.
        """
        self._ensure_loaded()
        file_entry = self._cache.get("files", {}).get(file_path, {})
        return file_entry.get("findings", [])

    def update_file(self, file_path: str, findings: List[Dict[str, Any]]) -> None:
        """Update cache for a file with new findings.

        Records the file's current hash, the findings, the current time as
        ``last_scanned`` and the number of findings, then marks the cache as
        changed.

        Args:
            file_path: Path to the file.
            findings: List of findings for this file.
        """
        self._ensure_loaded()
        current_hash = self.compute_file_hash(file_path)
        self._cache.setdefault("files", {})[file_path] = {
            "hash": current_hash,
            "findings": findings,
            "last_scanned": time.time(),
            "finding_count": len(findings),
        }
        self.changed = True

    def invalidate_file(self, file_path: str) -> None:
        """Remove a file from the cache.

        Does nothing (and does not mark the cache changed) if the file has no
        cache entry.

        Args:
            file_path: Path to the file.
        """
        self._ensure_loaded()
        if file_path in self._cache.get("files", {}):
            del self._cache["files"][file_path]
            self.changed = True

    def clear(self) -> None:
        """Clear the entire cache.

        The cache is also marked as loaded: otherwise a clear issued before
        the first lazy load would be undone by the next accessor re-reading
        the old file from disk.
        """
        self._cache = self._empty_cache()
        self._loaded = True
        self.changed = True

    def cleanup_expired(self) -> int:
        """Remove expired cache entries.

        An entry is expired when its ``last_scanned`` timestamp (treated as 0
        when missing) is older than ``max_age_days`` days. Expiry is skipped
        entirely when ``max_age_days`` is falsy.

        Returns:
            Number of expired entries removed.
        """
        self._ensure_loaded()
        if not self.max_age_days:
            return 0

        cutoff = time.time() - (self.max_age_days * 86400)  # 86400 s per day
        files = self._cache.get("files", {})
        # Collect first, delete afterwards: a dict must not be resized while
        # it is being iterated.
        expired = [
            file_path
            for file_path, entry in files.items()
            if entry.get("last_scanned", 0) < cutoff
        ]
        for file_path in expired:
            del files[file_path]

        if expired:
            self.changed = True
            logger.info("Cleaned up %d expired cache entries", len(expired))
        return len(expired)

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dictionary with ``cached_files``, ``total_cached_findings``,
            ``cache_size_bytes`` (size of the cache file on disk, 0 if it
            cannot be stat'ed) and ``last_saved`` (None if never saved).
        """
        self._ensure_loaded()
        files = self._cache.get("files", {})
        total_findings = sum(e.get("finding_count", 0) for e in files.values())
        try:
            cache_size = self.cache_file.stat().st_size
        except OSError:
            # Missing (or concurrently removed) cache file.
            cache_size = 0
        return {
            "cached_files": len(files),
            "total_cached_findings": total_findings,
            "cache_size_bytes": cache_size,
            "last_saved": self._cache.get("metadata", {}).get("last_saved"),
        }