# NOTE(review): the three lines below the original header ("Spaces: / Running / Running")
# were a Hugging Face Spaces status banner captured during page export — not part of the script.
#!/usr/bin/env python3
"""
Download knowledge graph from a Hugging Face dataset directly into the final storage directory.
No extra copy step, so no disk duplication.
"""
import logging
import os
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import HfApi, snapshot_download

# Load .env if present, but never clobber variables already set in the environment.
load_dotenv(dotenv_path=".env", override=False)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def default_storage_root() -> Path:
    """Return the LightRAG storage root.

    Prefers the persistent ``/data`` volume (present on Hugging Face Spaces
    with persistent storage) and falls back to a relative ``data/`` directory
    for local runs.
    """
    persistent = Path("/data")
    if persistent.exists():
        return persistent / "rag_storage"
    return Path("data/rag_storage")
def default_hf_home() -> Path:
    """Return the Hugging Face cache directory.

    Uses ``/data/.huggingface`` when a persistent volume exists, otherwise a
    throwaway cache under ``/tmp``.
    """
    persistent = Path("/data")
    return persistent / ".huggingface" if persistent.exists() else Path("/tmp/.cache/huggingface")
def download_knowledge_graph() -> bool:
    """Download the knowledge graph for the configured jurisdictions from a
    Hugging Face dataset straight into the final storage directory (no
    intermediate copy, so no disk duplication).

    Environment variables:
        JURISDICTIONS: comma-separated jurisdiction names (default "romania,bahrain").
        HF_KNOWLEDGE_GRAPH_DATASET: dataset repo id to download from.
        HF_TOKEN: optional token for private datasets.
        HF_HOME: Hugging Face cache dir (defaults via default_hf_home()).
        LIGHTRAG_STORAGE_ROOT: final storage root (defaults via default_storage_root()).

    Returns:
        True when every jurisdiction folder exists under the storage root
        after the download; False on any error or missing folder.
    """
    jurisdictions_str = os.getenv("JURISDICTIONS", "romania,bahrain")
    jurisdictions = [j.strip() for j in jurisdictions_str.split(",") if j.strip()]
    dataset_id = os.getenv(
        "HF_KNOWLEDGE_GRAPH_DATASET",
        "Cyberlgl/CyberLegalAI-knowledge-graph",
    )
    hf_token = os.getenv("HF_TOKEN")
    hf_home = Path(os.getenv("HF_HOME", str(default_hf_home())))
    target_base_dir = Path(os.getenv("LIGHTRAG_STORAGE_ROOT", str(default_storage_root())))

    # Point the hub cache at a writable location before any hub call.
    os.environ["HF_HOME"] = str(hf_home)
    hf_home.mkdir(parents=True, exist_ok=True)
    target_base_dir.mkdir(parents=True, exist_ok=True)

    api = HfApi(token=hf_token)
    repo_files = api.list_repo_files(repo_id=dataset_id, repo_type="dataset")
    logger.info(f"π Repo contains {len(repo_files)} files")
    for path in repo_files[:200]:
        logger.info(f" - {path}")

    # Cover both top-level "<jurisdiction>/..." layouts and jurisdiction
    # folders nested one level deeper in the repo.
    allow_patterns = []
    for jurisdiction in jurisdictions:
        allow_patterns.extend([
            f"{jurisdiction}*",
            f"{jurisdiction}/*",
            f"{jurisdiction}/**",
            f"{jurisdiction}/**/*",
            f"*/{jurisdiction}/*",
            f"*/{jurisdiction}/**/*",
        ])
    logger.info(f"π§© allow_patterns={allow_patterns}")

    logger.info("=" * 80)
    logger.info("π Starting Knowledge Graph Download")
    logger.info(f"π¦ Dataset: {dataset_id}")
    logger.info(f"π Jurisdictions: {', '.join(jurisdictions)}")
    logger.info(f"πΎ HF_HOME: {hf_home}")
    logger.info(f"π Final storage root: {target_base_dir}")
    logger.info("=" * 80)
    try:
        logger.info("π Analyzing dataset structure...")
        # BUG FIX: repo_type="dataset" was commented out here, so the dry run
        # targeted the default "model" repo type instead of the dataset.
        dry_run_info = snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            allow_patterns=allow_patterns,
            token=hf_token,
            dry_run=True,  # requires a huggingface_hub version with dry_run support
        )
        to_download = [f for f in dry_run_info if getattr(f, "will_download", False)]
        total_bytes = sum(getattr(f, "size", 0) or 0 for f in to_download)
        logger.info(
            f"π Dry-run complete: {len(dry_run_info)} matching files, "
            f"{len(to_download)} to download, "
            f"{total_bytes / (1024 * 1024):.1f} MB total"
        )
        # Log files per jurisdiction (substring match on the file path —
        # may over-count if one jurisdiction name contains another).
        for jurisdiction in jurisdictions:
            jur_files = [f for f in to_download if jurisdiction in str(f)]
            jur_bytes = sum(getattr(f, "size", 0) or 0 for f in jur_files)
            logger.info(
                f" π¦ {jurisdiction}: {len(jur_files)} files, "
                f"{jur_bytes / (1024 * 1024):.1f} MB"
            )

        logger.info("π Starting download to final directory...")
        logger.info(f" π Target: {target_base_dir}")
        logger.info(" β³ Downloading files (this may take several minutes)...")
        # snapshot_download renders its own tqdm progress bars; the previous
        # hand-rolled ProgressCallback was never passed to the hub and has
        # been removed as dead code.
        snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            allow_patterns=allow_patterns,
            local_dir=str(target_base_dir),
            token=hf_token,
        )
        logger.info("β Download completed successfully")
        logger.info(" π¦ All files downloaded and verified")

        # Verify every requested jurisdiction actually landed on disk.
        missing = [j for j in jurisdictions if not (target_base_dir / j).exists()]
        if missing:
            logger.error(f"β Missing jurisdiction folders after download: {missing}")
            return False
        for j in jurisdictions:
            jur_dir = target_base_dir / j
            size = sum(f.stat().st_size for f in jur_dir.rglob("*") if f.is_file())
            logger.info(f"β {j}: {size / (1024 * 1024):.1f} MB ready in {jur_dir}")
        logger.info("=" * 80)
        logger.info("π Knowledge Graph Download Complete")
        logger.info("=" * 80)
        return True
    except Exception as e:
        # Top-level boundary: log the full traceback and report failure.
        logger.exception(f"β Error downloading knowledge graph: {e}")
        return False
if __name__ == "__main__":
    # Map the boolean result to a process exit code: 0 = success, 1 = failure.
    raise SystemExit(0 if download_knowledge_graph() else 1)