ddi / src /preprocessing /artifact_manager.py
github-actions[bot]
Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
d29b763
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union
import hashlib
import pandas as pd
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Project root: the third ancestor of this file's resolved path.
BASE_DIR = Path(__file__).resolve().parents[2]

# Data directories, all anchored at the project root.
DATA_DIR = BASE_DIR.joinpath('data')
PROCESSED_DIR = DATA_DIR.joinpath('processed')
RAW_DIR = DATA_DIR.joinpath('raw')
CACHE_DIR = BASE_DIR.joinpath('cache')

# JSON manifest describing registered artifacts; stored next to the
# processed data files.
MANIFEST_PATH = PROCESSED_DIR.joinpath('artifact_manifest.json')
def compute_checksum(file_path: Path) -> str:
    """Return the MD5 hex digest of the file at *file_path*.

    The file is opened in binary mode and consumed in 4096-byte blocks, so
    the whole file is never held in memory at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if block == b"":
                # End of file reached.
                break
            digest.update(block)
    return digest.hexdigest()
class ArtifactManager:
    """Centralized Artifact Manager for structured datasets.

    The class is a singleton: every ``ArtifactManager()`` call returns the
    same instance, and the manifest is read from ``MANIFEST_PATH`` only the
    first time the instance is initialised.

    The manifest is a JSON object shaped like
    ``{"artifacts": {<name>: {"path", "columns", "rows", "schema",
    "checksum", "version"}}, "version": "1.0"}``.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ArtifactManager() call even though __new__
        # hands back the same object, so guard against reloading.
        if not hasattr(self, 'initialized'):
            self.manifest = self._load_manifest()
            self.initialized = True

    @staticmethod
    def _normalize_manifest(data: Any) -> Dict[str, Any]:
        """Coerce parsed manifest content into a usable manifest dict.

        Anything that is not a dict is replaced by a fresh empty manifest.
        A dict whose ``"artifacts"`` entry is missing or not a dict gets an
        empty ``"artifacts"`` mapping (the dict is updated in place), so the
        other methods can index it without raising.
        """
        if not isinstance(data, dict):
            return {"artifacts": {}, "version": "1.0"}
        if not isinstance(data.get("artifacts"), dict):
            data["artifacts"] = {}
        return data

    @staticmethod
    def _schema_mismatches(df: pd.DataFrame, expected_schema: Dict[str, str]) -> Dict[str, tuple]:
        """Return ``{column: (expected_dtype, actual_dtype)}`` for mismatches.

        Columns absent from *df* are ignored, and so are columns whose actual
        dtype string contains ``'object'`` (the same lenient rule the
        original inline check used).
        """
        mismatches = {}
        for col, expected_type in (expected_schema or {}).items():
            if col not in df.columns:
                continue
            actual_type = str(df[col].dtype)
            if expected_type != actual_type and 'object' not in actual_type:
                mismatches[col] = (expected_type, actual_type)
        return mismatches

    def _load_manifest(self) -> Dict[str, Any]:
        """Read the manifest from ``MANIFEST_PATH``.

        Returns an empty manifest when the file does not exist or cannot be
        read/parsed; the failure is logged as a warning, not raised.
        """
        if MANIFEST_PATH.exists():
            try:
                with open(MANIFEST_PATH, 'r', encoding='utf-8') as f:
                    return self._normalize_manifest(json.load(f))
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and decoding errors.
                logger.warning(f"Could not load manifest: {e}")
        return self._normalize_manifest(None)

    def _save_manifest(self):
        """Write the in-memory manifest to ``MANIFEST_PATH`` as indented JSON."""
        # Serialise first: if the manifest holds something json cannot
        # encode, the error surfaces before the existing file is truncated.
        payload = json.dumps(self.manifest, indent=2)
        PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
        with open(MANIFEST_PATH, 'w', encoding='utf-8') as f:
            f.write(payload)

    def register_artifact(self, name: str, df: pd.DataFrame, file_path: Path):
        """Register an artifact in the manifest with schema info.

        Records the file's location, column list, row count, per-column dtype
        strings and MD5 checksum under *name*, then saves the manifest.

        Args:
            name: Key the artifact is stored under in the manifest.
            df: The DataFrame whose columns/dtypes/length are recorded.
            file_path: Location of the artifact file on disk. Stored relative
                to ``BASE_DIR`` when possible, otherwise as an absolute path.

        Raises:
            OSError: If *file_path* cannot be read to compute its checksum.
        """
        schema = {col: str(dtype) for col, dtype in df.dtypes.items()}
        checksum = compute_checksum(file_path)

        try:
            stored_path = file_path.relative_to(BASE_DIR)
        except ValueError:
            # Relative input or a file outside BASE_DIR: retry with the
            # resolved path, and fall back to the absolute path. An absolute
            # entry still works on load because ``BASE_DIR / absolute``
            # evaluates to the absolute path.
            absolute = file_path.resolve()
            try:
                stored_path = absolute.relative_to(BASE_DIR)
            except ValueError:
                stored_path = absolute

        self.manifest.setdefault('artifacts', {})[name] = {
            'path': str(stored_path),
            'columns': list(df.columns),
            'rows': len(df),
            'schema': schema,
            'checksum': checksum,
            'version': "1.0"
        }
        self._save_manifest()

    def load_artifact(self, name: str, required_columns: Optional[list] = None, validate_schema: bool = True) -> pd.DataFrame:
        """Load an artifact securely with validation.

        The file is located through the manifest entry for *name*, or at
        ``PROCESSED_DIR/<name>.parquet`` when there is no entry. If the file
        is missing, or its checksum differs from the recorded one, a forced
        rebuild is triggered and the manifest is reloaded before reading.

        Args:
            name: Manifest key / parquet file stem of the artifact.
            required_columns: Columns that must be present in the result.
            validate_schema: When true, dtype differences from the recorded
                schema are logged as a warning (they do not raise).

        Returns:
            The artifact as a DataFrame.

        Raises:
            FileNotFoundError: If the file is still absent after a rebuild.
            ValueError: If any of *required_columns* is missing.
        """
        entry = self.manifest.get('artifacts', {}).get(name, {})
        path_str = entry.get('path')
        needs_rebuild = False

        if not path_str:
            path = PROCESSED_DIR / f"{name}.parquet"
            if not path.exists():
                needs_rebuild = True
        else:
            path = BASE_DIR / path_str
            if not path.exists():
                needs_rebuild = True
            else:
                expected_checksum = entry.get('checksum')
                if expected_checksum and compute_checksum(path) != expected_checksum:
                    logger.warning(f"Checksum mismatch for artifact {name}. Triggering rebuild...")
                    needs_rebuild = True

        if needs_rebuild:
            logger.info(f"Artifact {name} missing or invalid. Triggering rebuild...")
            # NOTE(review): imported here rather than at module top,
            # presumably to avoid a circular import - confirm.
            from preprocessing.artifact_store import ensure_structured_data
            ensure_structured_data(force_rebuild=True)
            self.manifest = self._load_manifest()  # Reload manifest
            # Update path resolution in case it changed
            path_str = self.manifest.get('artifacts', {}).get(name, {}).get('path')
            if path_str:
                path = BASE_DIR / path_str
            elif not path.exists():
                path = PROCESSED_DIR / f"{name}.parquet"

        if not path.exists():
            raise FileNotFoundError(f"Failed to find or rebuild artifact: {name} at {path}")

        df = pd.read_parquet(path)

        if required_columns:
            missing = [col for col in required_columns if col not in df.columns]
            if missing:
                raise ValueError(f"Artifact {name} missing required columns: {missing}")

        if validate_schema:
            expected_schema = self.manifest.get('artifacts', {}).get(name, {}).get('schema', {})
            mismatches = self._schema_mismatches(df, expected_schema)
            if mismatches:
                # Report only; a dtype drift is not treated as fatal.
                logger.warning(f"Schema mismatch for artifact {name} (expected, actual): {mismatches}")

        return df

    def verify_all_artifacts(self) -> bool:
        """Verify the integrity of all artifacts in the manifest.

        An artifact is valid when its manifest entry has a path, the file
        exists under ``BASE_DIR`` and its MD5 checksum equals the recorded
        one. Each failure is logged as a warning.

        Returns:
            True when every artifact is valid (including when the manifest
            lists none), False otherwise.
        """
        all_valid = True
        for name, metadata in self.manifest.get('artifacts', {}).items():
            path_str = metadata.get('path')
            if not path_str:
                logger.warning(f"Artifact {name} has no path in the manifest")
                all_valid = False
                continue
            path = BASE_DIR / path_str
            if not path.exists():
                logger.warning(f"Artifact {name} file not found: {path}")
                all_valid = False
                continue
            if compute_checksum(path) != metadata.get('checksum'):
                logger.warning(f"Artifact {name} failed checksum verification")
                all_valid = False
        return all_valid
# Shared module-level instance. ArtifactManager is a singleton, so any later
# ArtifactManager() call returns this same object; creating it here loads the
# manifest when this module is imported.
manager = ArtifactManager()