#!/usr/bin/env python3
"""
Download knowledge graph from a Hugging Face dataset directly into the final storage directory.
No extra copy step, so no disk duplication.
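
Reads JURISDICTIONS, HF_KNOWLEDGE_GRAPH_DATASET, HF_TOKEN, HF_HOME and
LIGHTRAG_STORAGE_ROOT from the environment (or a .env file). Example
invocation, with illustrative values matching the defaults below:

    JURISDICTIONS=romania,bahrain \
    HF_KNOWLEDGE_GRAPH_DATASET=Cyberlgl/CyberLegalAI-knowledge-graph \
    LIGHTRAG_STORAGE_ROOT=/data/rag_storage \
    python scripts/download_knowledge_graph.py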
"""
import os
import logging
from pathlib import Path
from dotenv import load_dotenv
from huggingface_hub import HfApi, snapshot_download

load_dotenv(dotenv_path=".env", override=False)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def default_storage_root() -> Path:
    """Return the LightRAG storage root, preferring the persistent /data volume when present."""
    if Path("/data").exists():
        return Path("/data/rag_storage")
    return Path("data/rag_storage")


def default_hf_home() -> Path:
    """Return the Hugging Face cache directory, preferring the persistent /data volume when present."""
    if Path("/data").exists():
        return Path("/data/.huggingface")
    return Path("/tmp/.cache/huggingface")


def download_knowledge_graph() -> bool:
    """Download the configured jurisdictions' knowledge graph files into the final storage root."""
    jurisdictions_str = os.getenv("JURISDICTIONS", "romania,bahrain")
    jurisdictions = [j.strip() for j in jurisdictions_str.split(",") if j.strip()]
    dataset_id = os.getenv(
        "HF_KNOWLEDGE_GRAPH_DATASET",
        "Cyberlgl/CyberLegalAI-knowledge-graph",
    )
    hf_token = os.getenv("HF_TOKEN")
    hf_home = Path(os.getenv("HF_HOME", str(default_hf_home())))
    target_base_dir = Path(os.getenv("LIGHTRAG_STORAGE_ROOT", str(default_storage_root())))

    os.environ["HF_HOME"] = str(hf_home)
    hf_home.mkdir(parents=True, exist_ok=True)
    target_base_dir.mkdir(parents=True, exist_ok=True)

    api = HfApi(token=hf_token)
    repo_files = api.list_repo_files(repo_id=dataset_id, repo_type="dataset")
    logger.info(f"📂 Repo contains {len(repo_files)} files")
    for path in repo_files[:200]:
        logger.info(f" - {path}")
    # Match jurisdiction folders whether they sit at the repo root or nested one level down.
    allow_patterns = []
    for jurisdiction in jurisdictions:
        allow_patterns.extend([
            f"{jurisdiction}*",
            f"{jurisdiction}/*",
            f"{jurisdiction}/**",
            f"{jurisdiction}/**/*",
            f"*/{jurisdiction}/*",
            f"*/{jurisdiction}/**/*",
        ])
logger.info(f"🧩 allow_patterns={allow_patterns}")
logger.info("=" * 80)
logger.info("πŸš€ Starting Knowledge Graph Download")
logger.info(f"πŸ“¦ Dataset: {dataset_id}")
logger.info(f"🌍 Jurisdictions: {', '.join(jurisdictions)}")
logger.info(f"πŸ’Ύ HF_HOME: {hf_home}")
logger.info(f"πŸ“ Final storage root: {target_base_dir}")
logger.info("=" * 80)
    try:
        # Dry run first so the log shows what would be fetched before any bytes move.
        logger.info("🔍 Analyzing dataset structure...")
        dry_run_info = snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            allow_patterns=allow_patterns,
            token=hf_token,
            dry_run=True,
        )
        to_download = [f for f in dry_run_info if getattr(f, "will_download", False)]
        total_bytes = sum(getattr(f, "size", 0) or 0 for f in to_download)
        logger.info(
            f"📋 Dry-run complete: {len(dry_run_info)} matching files, "
            f"{len(to_download)} to download, "
            f"{total_bytes / (1024 * 1024):.1f} MB total"
        )
        # Log files per jurisdiction
        for jurisdiction in jurisdictions:
            jur_files = [f for f in to_download if jurisdiction in str(f)]
            jur_bytes = sum(getattr(f, "size", 0) or 0 for f in jur_files)
            logger.info(
                f" 📦 {jurisdiction}: {len(jur_files)} files, "
                f"{jur_bytes / (1024 * 1024):.1f} MB"
            )

        logger.info("🚀 Starting download to final directory...")
        logger.info(f" 📁 Target: {target_base_dir}")
        logger.info(" ⏳ Downloading files (this may take several minutes)...")
        # snapshot_download displays its own tqdm progress bars when tqdm is installed,
        # so no custom progress callback needs to be wired up here.
        try:
            import tqdm  # noqa: F401
            logger.info(" 📊 tqdm available, download progress bars enabled")
        except ImportError:
            logger.info(" ℹ️ tqdm not available, using basic logging")
        snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            allow_patterns=allow_patterns,
            local_dir=str(target_base_dir),
            token=hf_token,
        )

        logger.info("✅ Download completed successfully")
        logger.info(" 📦 All files downloaded and verified")
        missing = [j for j in jurisdictions if not (target_base_dir / j).exists()]
        if missing:
            logger.error(f"❌ Missing jurisdiction folders after download: {missing}")
            return False

        for j in jurisdictions:
            jur_dir = target_base_dir / j
            size = sum(f.stat().st_size for f in jur_dir.rglob("*") if f.is_file())
            logger.info(f"✅ {j}: {size / (1024 * 1024):.1f} MB ready in {jur_dir}")

        logger.info("=" * 80)
        logger.info("🎉 Knowledge Graph Download Complete")
        logger.info("=" * 80)
        return True
    except Exception as e:
        logger.exception(f"❌ Error downloading knowledge graph: {e}")
        return False


if __name__ == "__main__":
    raise SystemExit(0 if download_knowledge_graph() else 1)