File size: 6,147 Bytes
7414a53
 
929235c
 
7414a53
 
 
 
 
 
d262f27
 
7414a53
 
 
929235c
7414a53
 
 
929235c
 
 
 
 
 
 
 
 
 
 
 
 
7414a53
929235c
 
 
 
 
 
7414a53
929235c
 
 
 
 
 
 
 
d262f27
 
 
 
 
 
be4429e
 
 
 
d262f27
be4429e
 
 
d262f27
 
be4429e
 
 
929235c
7414a53
 
 
 
929235c
 
7414a53
929235c
7414a53
929235c
 
 
8ec8ca7
929235c
 
 
 
 
 
 
 
 
 
 
 
 
 
7414a53
929235c
 
 
 
 
7414a53
929235c
 
 
 
 
 
 
 
7414a53
929235c
 
 
 
 
 
 
 
 
 
 
 
 
 
7414a53
929235c
 
7414a53
929235c
 
 
7414a53
929235c
 
 
 
 
 
 
 
 
 
 
 
 
 
7414a53
929235c
 
7414a53
929235c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7414a53
 
929235c
7414a53
929235c
7414a53
 
 
 
929235c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#!/usr/bin/env python3
"""
Download knowledge graph from a Hugging Face dataset directly into the final storage directory.
No extra copy step, so no disk duplication.
"""
import os
import logging
from pathlib import Path
from huggingface_hub import snapshot_download
from dotenv import load_dotenv
# NOTE(review): huggingface_hub is imported twice; the two lines could be merged.
from huggingface_hub import HfApi


# Populate os.environ from a local .env file; override=False means variables
# already present in the environment take precedence over the file.
load_dotenv(dotenv_path=".env", override=False)

# Timestamped INFO-level logging for the whole script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def default_storage_root() -> Path:
    """Pick the storage root: a persistent /data volume when mounted, else a local fallback."""
    data_volume = Path("/data")
    return data_volume / "rag_storage" if data_volume.exists() else Path("data/rag_storage")


def default_hf_home() -> Path:
    """Pick the Hugging Face cache dir: persistent under /data when mounted, else a tmp cache."""
    if not Path("/data").exists():
        return Path("/tmp/.cache/huggingface")
    return Path("/data/.huggingface")


def download_knowledge_graph() -> bool:
    """Download the knowledge-graph dataset directly into the final storage directory.

    Configuration is read from environment variables:
      JURISDICTIONS               comma-separated list (default "romania,bahrain")
      HF_KNOWLEDGE_GRAPH_DATASET  dataset repo id on the Hub
      HF_TOKEN                    optional auth token
      HF_HOME                     Hub cache location (defaults via default_hf_home)
      LIGHTRAG_STORAGE_ROOT       final storage root (defaults via default_storage_root)

    Returns:
        True when every requested jurisdiction folder exists after the download,
        False on any error or missing folder.
    """
    jurisdictions_str = os.getenv("JURISDICTIONS", "romania,bahrain")
    jurisdictions = [j.strip() for j in jurisdictions_str.split(",") if j.strip()]

    dataset_id = os.getenv(
        "HF_KNOWLEDGE_GRAPH_DATASET",
        "Cyberlgl/CyberLegalAI-knowledge-graph",
    )
    hf_token = os.getenv("HF_TOKEN")

    hf_home = Path(os.getenv("HF_HOME", str(default_hf_home())))
    target_base_dir = Path(os.getenv("LIGHTRAG_STORAGE_ROOT", str(default_storage_root())))

    # Ensure every downstream hub call uses the same cache location.
    os.environ["HF_HOME"] = str(hf_home)
    hf_home.mkdir(parents=True, exist_ok=True)
    target_base_dir.mkdir(parents=True, exist_ok=True)

    api = HfApi(token=hf_token)
    repo_files = api.list_repo_files(repo_id=dataset_id, repo_type="dataset")

    logger.info(f"πŸ“‚ Repo contains {len(repo_files)} files")
    for path in repo_files[:200]:
        logger.info(f"   - {path}")

    # Match each jurisdiction folder whether it sits at the repo root or one
    # level deeper; the redundant variants cover differing glob semantics.
    allow_patterns = []
    for jurisdiction in jurisdictions:
        allow_patterns.extend([
            f"{jurisdiction}*",
            f"{jurisdiction}/*",
            f"{jurisdiction}/**",
            f"{jurisdiction}/**/*",
            f"*/{jurisdiction}/*",
            f"*/{jurisdiction}/**/*",
        ])

    logger.info(f"🧩 allow_patterns={allow_patterns}")

    logger.info("=" * 80)
    logger.info("πŸš€ Starting Knowledge Graph Download")
    logger.info(f"πŸ“¦ Dataset: {dataset_id}")
    logger.info(f"🌍 Jurisdictions: {', '.join(jurisdictions)}")
    logger.info(f"πŸ’Ύ HF_HOME: {hf_home}")
    logger.info(f"πŸ“ Final storage root: {target_base_dir}")
    logger.info("=" * 80)

    try:
        logger.info("πŸ” Analyzing dataset structure...")
        dry_run_info = snapshot_download(
            repo_id=dataset_id,
            # FIX: repo_type was commented out here, so the dry run defaulted
            # to repo_type="model" and analyzed the wrong (nonexistent) repo.
            repo_type="dataset",
            allow_patterns=allow_patterns,
            token=hf_token,
            dry_run=True,
        )
        to_download = [f for f in dry_run_info if getattr(f, "will_download", False)]
        total_bytes = sum(getattr(f, "size", 0) or 0 for f in to_download)

        logger.info(
            f"πŸ“‹ Dry-run complete: {len(dry_run_info)} matching files, "
            f"{len(to_download)} to download, "
            f"{total_bytes / (1024 * 1024):.1f} MB total"
        )

        # Log files per jurisdiction
        for jurisdiction in jurisdictions:
            jur_files = [f for f in to_download if jurisdiction in str(f)]
            jur_bytes = sum(getattr(f, "size", 0) or 0 for f in jur_files)
            logger.info(
                f"   πŸ“¦ {jurisdiction}: {len(jur_files)} files, "
                f"{jur_bytes / (1024 * 1024):.1f} MB"
            )

        logger.info("πŸš€ Starting download to final directory...")
        logger.info(f"   πŸ“ Target: {target_base_dir}")
        logger.info("   ⏳ Downloading files (this may take several minutes)...")

        # snapshot_download renders its own per-file tqdm progress bars when
        # tqdm is installed. The previous hand-rolled ProgressCallback was dead
        # code: its bar was created but never updated during the download, and
        # its helper variables were unused, so it has been removed.
        snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            allow_patterns=allow_patterns,
            local_dir=str(target_base_dir),
            token=hf_token,
        )

        logger.info("βœ… Download completed successfully")
        logger.info("   πŸ“¦ All files downloaded and verified")

        # Verify that every requested jurisdiction folder actually landed.
        missing = [j for j in jurisdictions if not (target_base_dir / j).exists()]
        if missing:
            logger.error(f"❌ Missing jurisdiction folders after download: {missing}")
            return False

        for j in jurisdictions:
            jur_dir = target_base_dir / j
            size = sum(f.stat().st_size for f in jur_dir.rglob("*") if f.is_file())
            logger.info(f"βœ… {j}: {size / (1024 * 1024):.1f} MB ready in {jur_dir}")

        logger.info("=" * 80)
        logger.info("πŸŽ‰ Knowledge Graph Download Complete")
        logger.info("=" * 80)
        return True

    except Exception as e:
        # Top-level boundary: log the full traceback and report failure to the caller.
        logger.exception(f"❌ Error downloading knowledge graph: {e}")
        return False


if __name__ == "__main__":
    # Exit code 0 on success, 1 on failure, so shell/CI callers can branch on it.
    raise SystemExit(0 if download_knowledge_graph() else 1)