oneblackmage committed on
Commit
7849935
·
verified ·
1 Parent(s): ac9bb45

Upload folder using huggingface_hub

Browse files
utils/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared AI utilities for training and dataset pipelines.
3
+ """
4
+
5
+ from .ngc_cli import (
6
+ NGCCLI,
7
+ NGCCLIAuthError,
8
+ NGCCLIDownloadError,
9
+ NGCCLIError,
10
+ NGCCLINotFoundError,
11
+ NGCConfig,
12
+ ensure_ngc_cli_configured,
13
+ get_ngc_cli,
14
+ )
15
+
16
+ __all__ = [
17
+ "NGCCLI",
18
+ "NGCCLIAuthError",
19
+ "NGCCLIDownloadError",
20
+ "NGCCLIError",
21
+ "NGCCLINotFoundError",
22
+ "NGCConfig",
23
+ "ensure_ngc_cli_configured",
24
+ "get_ngc_cli",
25
+ ]
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (584 Bytes). View file
 
utils/__pycache__/llm_capabilities.cpython-311.pyc ADDED
Binary file (3.35 kB). View file
 
utils/__pycache__/ngc_cli.cpython-311.pyc ADDED
Binary file (17.8 kB). View file
 
utils/__pycache__/ngc_resources.cpython-311.pyc ADDED
Binary file (6.86 kB). View file
 
utils/__pycache__/s3_dataset_loader.cpython-311.pyc ADDED
Binary file (22.3 kB). View file
 
utils/__pycache__/transcript_corrector.cpython-311.pyc ADDED
Binary file (7.8 kB). View file
 
utils/llm_capabilities.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from google import genai
4
+
5
+ _WORKING_MODEL_CACHE = None
6
+
7
+
8
def get_best_available_gemini_model(client: genai.Client) -> str:
    """
    Dynamically interrogates the Gemini API to find the best functioning
    model available for the current API key's tier/region. This prevents
    hardcoded models from throwing 404s if they are restricted.

    Args:
        client: An authenticated ``genai.Client``.

    Returns:
        A ``models/``-prefixed model identifier that both appears in the
        listing and answers a trivial generate call. The first verified
        model is cached in ``_WORKING_MODEL_CACHE`` for the process lifetime.
    """
    global _WORKING_MODEL_CACHE
    if _WORKING_MODEL_CACHE:
        return _WORKING_MODEL_CACHE

    # Ordered by preference; earlier entries are tried first.
    target_models = [
        "models/gemini-2.0-flash-001",
        "models/gemini-2.0-flash-lite-001",
        "models/gemini-flash-latest",
        "models/gemini-pro-latest",
        "models/gemini-2.5-flash",
        "models/gemini-2.5-pro",
    ]

    try:
        available_models = [m.name for m in client.models.list()]
        print(f"DISCOVERED MODELS on this key: {available_models}")
    except Exception as e:
        print(f"Failed to list models: {e}")
        # Fix: previously returned the bare id "gemini-1.5-flash", which was
        # inconsistent with every other (models/-prefixed) return value and
        # not in the preferred list; use the documented fallback instead.
        return "models/gemini-flash-latest"

    for target in target_models:
        # Fix: the old nested loop retried the *same* invoke call (always
        # with model=target) once per matching listing entry; a single
        # membership test is sufficient.
        if not any(
            target == name or name.endswith(target) for name in available_models
        ):
            continue
        # Double check that we can actually invoke it
        # (some show up in list but 404 on invoke due to constraints)
        try:
            client.models.generate_content(model=target, contents="ping")
        except Exception as eval_e:
            print(f"Model {target} is listed but uninvokeable: {eval_e}")
            continue
        _WORKING_MODEL_CACHE = target
        print(f"Dynamically locked to functioning Gemini model: {target}")
        return target

    print(
        "CRITICAL WARNING: No preferred Gemini models available on this API Key. "
        "Falling back to gemini-flash-latest."
    )
    return "models/gemini-flash-latest"
53
+
54
+
55
def ensure_valid_key() -> str:
    """Validates that the Gemini API key provided is a REST key, not an OAuth token."""
    # GOOGLE_CLOUD_API_KEY takes precedence; GEMINI_API_KEY is the fallback.
    key = next(
        (
            candidate
            for candidate in (
                os.environ.get("GOOGLE_CLOUD_API_KEY"),
                os.environ.get("GEMINI_API_KEY"),
            )
            if candidate
        ),
        None,
    )
    if key is None:
        raise ValueError(
            "Neither GOOGLE_CLOUD_API_KEY nor GEMINI_API_KEY are configured."
        )
    # OAuth bearer tokens start with "AQ"; REST keys start with "AIza".
    if key.startswith("AQ"):
        raise ValueError(
            "Provided GEMINI_API_KEY is an OAuth token (AQ...). "
            "The AI engine requires a Google Cloud REST API key (AIza...). "
            "Please update your .env file."
        )
    return key
utils/ngc_cli.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NGC CLI Utility Module
3
+
4
+ Provides utilities for working with NVIDIA GPU Cloud (NGC) CLI to download
5
+ NeMo resources, datasets, and other NGC catalog resources.
6
+
7
+ This module handles:
8
+ - NGC CLI detection and installation
9
+ - Resource download from NGC catalog
10
+ - Configuration management
11
+ - Error handling and retry logic
12
+ """
13
+
14
+ import logging
15
+ import os
16
+ import shutil
17
+ import subprocess
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass
class NGCConfig:
    """NGC CLI configuration"""

    # NGC API key (obtained from https://catalog.ngc.nvidia.com); optional.
    api_key: str | None = None
    # Optional NGC organization name.
    org: str | None = None
    # Optional NGC team name.
    team: str | None = None
+ team: str | None = None
32
+
33
+
34
class NGCCLIError(Exception):
    """Base exception for NGC CLI operations; all errors in this module derive from it."""
36
+
37
+
38
class NGCCLINotFoundError(NGCCLIError):
    """Raised when the NGC CLI binary cannot be located (not installed or not on PATH)."""
40
+
41
+
42
class NGCCLIAuthError(NGCCLIError):
    """Raised when NGC CLI authentication/configuration is missing or fails."""
44
+
45
+
46
class NGCCLIDownloadError(NGCCLIError):
    """Raised when an NGC CLI download command exits with a non-zero status."""
48
+
49
+
50
+ class NGCCLI:
51
+ """
52
+ NGC CLI wrapper for downloading resources from NVIDIA GPU Cloud.
53
+
54
+ Supports multiple installation methods:
55
+ 1. System-installed ngc in PATH
56
+ 2. Local installation at ~/ngc-cli/ngc
57
+ 3. Python package via uv (ngc-python-cli)
58
+ """
59
+
60
+ def __init__(self, use_uv: bool = True):
61
+ """
62
+ Initialize NGC CLI wrapper.
63
+
64
+ Args:
65
+ use_uv: If True, prefer uv-based installation if ngc not in PATH
66
+ """
67
+ self.use_uv = use_uv
68
+ self.ngc_cmd: str | None = None
69
+ self.uv_cmd: str | None = None
70
+ self._detect_ngc_cli()
71
+
72
+ def _detect_ngc_cli(self) -> None:
73
+ """Detect and set up NGC CLI command"""
74
+ # Method 1: Check if ngc is in PATH
75
+ if shutil.which("ngc"):
76
+ self.ngc_cmd = "ngc"
77
+ logger.info("Found NGC CLI in PATH")
78
+ return
79
+
80
+ # Method 2: Check common installation location
81
+ home_ngc = Path.home() / "ngc-cli" / "ngc"
82
+ if home_ngc.exists():
83
+ self.ngc_cmd = str(home_ngc)
84
+ # Add to PATH for subprocess calls
85
+ env_path = os.environ.get("PATH", "")
86
+ os.environ["PATH"] = f"{home_ngc.parent}:{env_path}"
87
+ logger.info(f"Found NGC CLI at {home_ngc}")
88
+ return
89
+
90
+ # Method 3: Use uv to run ngc (if enabled)
91
+ if self.use_uv:
92
+ self._setup_uv_ngc()
93
+
94
+ def _setup_uv_ngc(self) -> None:
95
+ """Set up NGC CLI via uv"""
96
+ # Find uv
97
+ if shutil.which("uv"):
98
+ self.uv_cmd = "uv"
99
+ elif (Path.home() / ".local" / "bin" / "uv").exists():
100
+ self.uv_cmd = str(Path.home() / ".local" / "bin" / "uv")
101
+ elif (Path.home() / ".cargo" / "bin" / "uv").exists():
102
+ self.uv_cmd = str(Path.home() / ".cargo" / "bin" / "uv")
103
+ else:
104
+ logger.warning("uv not found, cannot use uv-based NGC CLI")
105
+ return
106
+
107
+ # Check if ngc is installed via uv
108
+ try:
109
+ result = subprocess.run(
110
+ [self.uv_cmd, "pip", "list"],
111
+ capture_output=True,
112
+ text=True,
113
+ check=False,
114
+ )
115
+ if "ngc" in result.stdout.lower():
116
+ self.ngc_cmd = f"{self.uv_cmd} run ngc"
117
+ logger.info("Found NGC CLI via uv")
118
+ return
119
+ except Exception as e:
120
+ logger.debug(f"Error checking uv packages: {e}")
121
+
122
+ # Note: NGC CLI is not a Python package on PyPI
123
+ # It must be downloaded from https://catalog.ngc.nvidia.com
124
+ # We can only check if it's available in PATH or local installation
125
+ # The uv method here is for running Python-based NGC SDK if available
126
+ logger.debug("NGC CLI must be installed separately from NVIDIA website")
127
+
128
+ def is_available(self) -> bool:
129
+ """Check if NGC CLI is available"""
130
+ return self.ngc_cmd is not None
131
+
132
+ def ensure_available(self) -> None:
133
+ """Ensure NGC CLI is available, raise error if not"""
134
+ if not self.is_available():
135
+ raise NGCCLINotFoundError(
136
+ "NGC CLI not found. Please install it:\n"
137
+ " 1. Download from https://catalog.ngc.nvidia.com\n"
138
+ " 2. Or install to ~/ngc-cli/ directory\n"
139
+ " 3. Or add to system PATH\n"
140
+ "\n"
141
+ "Note: NGC CLI is not available as a PyPI package.\n"
142
+ "You must download it directly from NVIDIA."
143
+ )
144
+
145
+ def check_config(self) -> dict[str, Any]:
146
+ """
147
+ Check NGC CLI configuration.
148
+
149
+ Returns:
150
+ Configuration dictionary with API key status, org, team, etc.
151
+
152
+ Raises:
153
+ NGCCLINotFoundError: If NGC CLI is not available
154
+ NGCCLIAuthError: If authentication is not configured
155
+ """
156
+ self.ensure_available()
157
+
158
+ if self.ngc_cmd is None:
159
+ raise NGCCLINotFoundError("NGC CLI command not set")
160
+ try:
161
+ result = subprocess.run(
162
+ [*self.ngc_cmd.split(), "config", "current"],
163
+ capture_output=True,
164
+ text=True,
165
+ check=True,
166
+ )
167
+
168
+ config = {}
169
+ # Parse the table format output
170
+ lines = result.stdout.strip().split("\n")
171
+ current_key = None
172
+
173
+ for line in lines:
174
+ if "|" in line and "| key " not in line.lower() and "---" not in line:
175
+ parts = [part.strip() for part in line.split("|") if part.strip()]
176
+ if len(parts) >= 3: # key | value | source
177
+ key, value, source = parts[0], parts[1], parts[2]
178
+ if key: # New key
179
+ current_key = key
180
+ config[key] = value
181
+ elif current_key and value: # Continuation of previous key
182
+ config[current_key] += value
183
+ elif len(parts) == 1 and current_key: # Just a value continuation
184
+ config[current_key] += parts[0]
185
+
186
+ # Check if API key is configured (it will be masked with asterisks)
187
+ # If we have any config and apikey exists (even masked), consider it configured
188
+ if config and ("apikey" in config or "API key" in config):
189
+ return config
190
+
191
+ raise NGCCLIAuthError(
192
+ "NGC CLI not configured. Run: ngc config set\n"
193
+ "Get your API key from: https://catalog.ngc.nvidia.com"
194
+ )
195
+
196
+ return config
197
+ except subprocess.CalledProcessError as e:
198
+ raise NGCCLIAuthError(f"Failed to check NGC config: {e.stderr}") from e
199
+
200
+ def set_config(
201
+ self, api_key: str, _org: str | None = None, _team: str | None = None
202
+ ) -> None:
203
+ """
204
+ Configure NGC CLI with API key.
205
+
206
+ Args:
207
+ api_key: NGC API key from https://catalog.ngc.nvidia.com
208
+ _org: Optional organization name (reserved for future use)
209
+ _team: Optional team name (reserved for future use)
210
+ """
211
+ self.ensure_available()
212
+
213
+ if self.ngc_cmd is None:
214
+ raise NGCCLINotFoundError("NGC CLI command not set")
215
+ # Set API key
216
+ try:
217
+ subprocess.run(
218
+ [*self.ngc_cmd.split(), "config", "set"],
219
+ input=f"{api_key}\n",
220
+ text=True,
221
+ check=True,
222
+ capture_output=True,
223
+ )
224
+ logger.info("NGC CLI configured successfully")
225
+ except subprocess.CalledProcessError as e:
226
+ raise NGCCLIAuthError(f"Failed to configure NGC CLI: {e.stderr}") from e
227
+
228
+ def download_resource(
229
+ self,
230
+ resource_path: str,
231
+ version: str | None = None,
232
+ output_dir: Path | None = None,
233
+ extract: bool = True, # noqa: ARG002
234
+ ) -> Path:
235
+ """
236
+ Download a resource from NGC catalog.
237
+
238
+ Args:
239
+ resource_path: Resource path in format "org/team/resource" or "nvidia/nemo-microservices/nemo-microservices-quickstart"
240
+ version: Optional version tag (e.g., "25.10")
241
+ output_dir: Optional output directory (defaults to current directory)
242
+ extract: Whether to extract downloaded archive
243
+
244
+ Returns:
245
+ Path to downloaded/extracted resource
246
+
247
+ Raises:
248
+ NGCCLINotFoundError: If NGC CLI is not available
249
+ NGCCLIAuthError: If authentication failed
250
+ NGCCLIDownloadError: If download failed
251
+ """
252
+ self.ensure_available()
253
+
254
+ # Check config first
255
+ try:
256
+ self.check_config()
257
+ except NGCCLIAuthError:
258
+ logger.warning("NGC CLI not configured. Attempting download anyway...")
259
+
260
+ if output_dir is None:
261
+ output_dir = Path.cwd()
262
+ else:
263
+ output_dir = Path(output_dir)
264
+ output_dir.mkdir(parents=True, exist_ok=True)
265
+
266
+ if self.ngc_cmd is None:
267
+ raise NGCCLINotFoundError("NGC CLI command not set")
268
+ # Build download command
269
+ cmd = [*self.ngc_cmd.split(), "registry", "resource", "download-version"]
270
+
271
+ resource_spec = f"{resource_path}:{version}" if version else resource_path
272
+
273
+ cmd.append(resource_spec)
274
+
275
+ # Change to output directory for download
276
+ original_cwd = Path.cwd()
277
+ try:
278
+ return self._execute_download_in_directory(output_dir, resource_spec, cmd)
279
+ finally:
280
+ os.chdir(original_cwd)
281
+
282
+ def _execute_download_in_directory(
283
+ self, output_dir: Path, resource_spec: str, cmd: list[str]
284
+ ) -> Path:
285
+ """
286
+ Execute download command in the specified directory and locate the downloaded resource.
287
+
288
+ Args:
289
+ output_dir: Directory to download into
290
+ resource_spec: Resource specification string for logging
291
+ cmd: Command to execute
292
+
293
+ Returns:
294
+ Path to the downloaded resource (most recently modified item, or output_dir if empty)
295
+
296
+ Raises:
297
+ NGCCLIDownloadError: If download fails
298
+ """
299
+ os.chdir(output_dir)
300
+ logger.info(f"Downloading {resource_spec} to {output_dir}...")
301
+
302
+ result = subprocess.run(cmd, capture_output=True, text=True, check=False)
303
+
304
+ if result.returncode != 0:
305
+ error_msg = result.stderr or result.stdout
306
+ raise NGCCLIDownloadError(
307
+ f"Failed to download {resource_spec}:\n{error_msg}"
308
+ )
309
+
310
+ logger.info(f"Successfully downloaded {resource_spec}")
311
+
312
+ if downloaded_items := list(output_dir.iterdir()):
313
+ # Return the most recently modified item
314
+ return max(downloaded_items, key=lambda p: p.stat().st_mtime)
315
+
316
+ return output_dir
317
+
318
+ def list_resources(
319
+ self, org: str | None = None, team: str | None = None
320
+ ) -> list[dict[str, Any]]:
321
+ """
322
+ List available resources in NGC catalog.
323
+
324
+ Args:
325
+ org: Optional organization filter
326
+ team: Optional team filter
327
+
328
+ Returns:
329
+ List of resource dictionaries
330
+ """
331
+ self.ensure_available()
332
+
333
+ if self.ngc_cmd is None:
334
+ raise NGCCLINotFoundError("NGC CLI command not set")
335
+ cmd = [*self.ngc_cmd.split(), "registry", "resource", "list"]
336
+
337
+ if org:
338
+ cmd.extend(["--org", org])
339
+ if team:
340
+ cmd.extend(["--team", team])
341
+
342
+ try:
343
+ subprocess.run(cmd, capture_output=True, text=True, check=True)
344
+
345
+ # Parse output (format may vary)
346
+ # TODO: Implement proper parsing based on actual NGC CLI output format
347
+ return []
348
+ except subprocess.CalledProcessError as e:
349
+ logger.warning(f"Failed to list resources: {e.stderr}")
350
+ return []
351
+
352
+
353
def get_ngc_cli(use_uv: bool = True) -> NGCCLI:
    """
    Get an NGC CLI instance.

    Args:
        use_uv: If True, prefer uv-based installation

    Returns:
        NGCCLI instance
    """
    # Thin factory; construction runs CLI auto-detection (NGCCLI.__init__
    # calls _detect_ngc_cli), so the returned instance may still be
    # unavailable — callers should check is_available().
    return NGCCLI(use_uv=use_uv)
364
+
365
+
366
def ensure_ngc_cli_configured(api_key: str | None = None) -> NGCCLI:
    """
    Ensure NGC CLI is available and configured.

    Args:
        api_key: Optional API key to configure (if not already configured)

    Returns:
        Configured NGCCLI instance

    Raises:
        NGCCLINotFoundError: If NGC CLI cannot be found or installed
        NGCCLIAuthError: If configuration fails
    """
    cli = get_ngc_cli()

    if not cli.is_available():
        # Fix: the previous message suggested `uv pip install ... ngc-python-cli`,
        # which contradicts this module's own note that NGC CLI is not
        # distributed on PyPI; point users at the official download instead.
        raise NGCCLINotFoundError(
            "NGC CLI not available.\n"
            "Download it from: https://catalog.ngc.nvidia.com\n"
            "(NGC CLI is not distributed as a PyPI package.)"
        )

    # Check if already configured; fall back to configuring with the
    # provided key, and fail loudly when neither is possible.
    try:
        cli.check_config()
        return cli
    except NGCCLIAuthError as err:
        if api_key:
            cli.set_config(api_key)
            return cli
        raise NGCCLIAuthError(
            "NGC CLI not configured. Provide API key or run: ngc config set\n"
            "Get API key from: https://catalog.ngc.nvidia.com"
        ) from err
utils/ngc_resources.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NGC Resources Downloader for Training Ready
3
+
4
+ Downloads NeMo resources and training-related assets from NGC catalog.
5
+ Integrates with training_ready pipeline for automated resource acquisition.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # Add parent directory to path for imports
14
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
15
+
16
+ from ai.utils.ngc_cli import (
17
+ NGCCLIAuthError,
18
+ NGCCLINotFoundError,
19
+ ensure_ngc_cli_configured,
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class NGCResourceDownloader:
26
+ """
27
+ Downloads NeMo and training resources from NGC catalog.
28
+ """
29
+
30
+ # Common NeMo resources used in training
31
+ NEMO_RESOURCES = {
32
+ "nemo-microservices-quickstart": {
33
+ "path": "nvidia/nemo-microservices/nemo-microservices-quickstart",
34
+ "default_version": "25.10",
35
+ "description": "NeMo Microservices quickstart package",
36
+ },
37
+ "nemo-framework": {
38
+ "path": "nvidia/nemo/nemo",
39
+ "default_version": None, # Use latest
40
+ "description": "NeMo framework for training",
41
+ },
42
+ "nemo-megatron": {
43
+ "path": "nvidia/nemo/nemo-megatron",
44
+ "default_version": None,
45
+ "description": "NeMo Megatron for large-scale training",
46
+ },
47
+ }
48
+
49
+ def __init__(self, output_base: Path | None = None, api_key: str | None = None):
50
+ """
51
+ Initialize NGC resource downloader.
52
+
53
+ Args:
54
+ output_base: Base directory for downloads (defaults to training_ready/resources/)
55
+ api_key: Optional NGC API key (if not set, will check environment or prompt)
56
+ """
57
+ if output_base is None:
58
+ output_base = Path(__file__).parent.parent / "resources"
59
+ self.output_base = Path(output_base)
60
+ self.output_base.mkdir(parents=True, exist_ok=True)
61
+
62
+ # Get API key from environment if not provided
63
+ if api_key is None:
64
+ api_key = os.environ.get("NGC_API_KEY")
65
+
66
+ try:
67
+ self.cli = ensure_ngc_cli_configured(api_key=api_key)
68
+ except (NGCCLINotFoundError, NGCCLIAuthError) as e:
69
+ logger.warning(f"NGC CLI not available: {e}")
70
+
71
+ self.cli = None
72
+
73
+ def download_nemo_quickstart(
74
+ self, version: str | None = None, output_dir: Path | None = None
75
+ ) -> Path:
76
+ """
77
+ Download NeMo Microservices quickstart package.
78
+
79
+ Args:
80
+ version: Version to download (defaults to 25.10)
81
+ output_dir: Output directory (defaults to resources/nemo-microservices/)
82
+
83
+ Returns:
84
+ Path to downloaded/extracted quickstart directory
85
+ """
86
+ if not self.cli:
87
+ raise NGCCLINotFoundError("NGC CLI not available")
88
+
89
+ if version is None:
90
+ version = self.NEMO_RESOURCES["nemo-microservices-quickstart"][
91
+ "default_version"
92
+ ]
93
+
94
+ if output_dir is None:
95
+ output_dir = self.output_base / "nemo-microservices"
96
+
97
+ resource_path = self.NEMO_RESOURCES["nemo-microservices-quickstart"]["path"]
98
+
99
+ logger.info(f"Downloading NeMo Microservices quickstart v{version}...")
100
+ return self.cli.download_resource(
101
+ resource_path=resource_path,
102
+ version=version,
103
+ output_dir=output_dir,
104
+ extract=True,
105
+ )
106
+
107
+ def download_nemo_framework(
108
+ self, version: str | None = None, output_dir: Path | None = None
109
+ ) -> Path:
110
+ """
111
+ Download NeMo framework.
112
+
113
+ Args:
114
+ version: Version to download
115
+ output_dir: Output directory
116
+
117
+ Returns:
118
+ Path to downloaded framework
119
+ """
120
+ if not self.cli:
121
+ raise NGCCLINotFoundError("NGC CLI not available")
122
+
123
+ if output_dir is None:
124
+ output_dir = self.output_base / "nemo-framework"
125
+
126
+ resource_path = self.NEMO_RESOURCES["nemo-framework"]["path"]
127
+
128
+ logger.info("Downloading NeMo framework...")
129
+ return self.cli.download_resource(
130
+ resource_path=resource_path,
131
+ version=version,
132
+ output_dir=output_dir,
133
+ extract=True,
134
+ )
135
+
136
+ def download_custom_resource(
137
+ self,
138
+ resource_path: str,
139
+ version: str | None = None,
140
+ output_dir: Path | None = None,
141
+ ) -> Path:
142
+ """
143
+ Download a custom resource from NGC catalog.
144
+
145
+ Args:
146
+ resource_path: Resource path (e.g., "nvidia/nemo-microservices/nemo-microservices-quickstart")
147
+ version: Optional version tag
148
+ output_dir: Optional output directory
149
+
150
+ Returns:
151
+ Path to downloaded resource
152
+ """
153
+ if not self.cli:
154
+ raise NGCCLINotFoundError("NGC CLI not available")
155
+
156
+ if output_dir is None:
157
+ # Create directory from resource name
158
+ resource_name = resource_path.split("/")[-1]
159
+ output_dir = self.output_base / resource_name
160
+
161
+ logger.info(f"Downloading {resource_path}...")
162
+ return self.cli.download_resource(
163
+ resource_path=resource_path,
164
+ version=version,
165
+ output_dir=output_dir,
166
+ extract=True,
167
+ )
168
+
169
+
170
def download_nemo_quickstart(
    version: str | None = None, output_dir: Path | None = None
) -> Path:
    """
    Convenience function to download NeMo Microservices quickstart.

    Note:
        Instantiates a fresh NGCResourceDownloader with default settings
        (default resource directory; NGC_API_KEY read from the environment).

    Args:
        version: Version to download (defaults to 25.10)
        output_dir: Output directory

    Returns:
        Path to downloaded quickstart directory
    """
    downloader = NGCResourceDownloader()
    return downloader.download_nemo_quickstart(version=version, output_dir=output_dir)
utils/s3_dataset_loader.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ S3 Dataset Loader - Streaming JSON/JSONL loader for S3 training data
4
+ S3 is the training mecca - all training data should be loaded from S3
5
+ """
6
+
7
+ import contextlib
8
+ import json
9
+ import logging
10
+ import os
11
+ from collections.abc import Iterator
12
+ from io import BytesIO
13
+ from pathlib import Path
14
+ from typing import TYPE_CHECKING, Any
15
+
16
+ try:
17
+ import boto3
18
+ from botocore.exceptions import ClientError as _BotocoreClientError
19
+ except ImportError:
20
+ # Keep runtime behavior (error on use) while making type checkers happy.
21
+ boto3 = None # type: ignore[assignment]
22
+ _BotocoreClientError = None # type: ignore[assignment]
23
+
24
+ if TYPE_CHECKING:
25
+ # Minimal shape we rely on in this module.
26
+ class ClientError(Exception):
27
+ response: dict[str, Any]
28
+ else:
29
+ ClientError = (
30
+ _BotocoreClientError if _BotocoreClientError is not None else Exception
31
+ ) # type: ignore[assignment]
32
+
33
+ BOTO3_AVAILABLE = boto3 is not None
34
+
35
+ # Load .env file if available
36
+ with contextlib.suppress(ImportError):
37
+ from dotenv import load_dotenv
38
+
39
+ # Try loading from ai/ directory first (where .env actually is), then project root
40
+ # Module is at: ai/training_ready/utils/s3_dataset_loader.py
41
+ # So parents[0] = ai/training_ready/utils/, parents[1] = ai/training_ready/,
42
+ # parents[2] = ai/, parents[3] = project root
43
+ module_path = Path(__file__).resolve()
44
+ env_paths = []
45
+ try:
46
+ env_paths.append(module_path.parents[2] / ".env") # ai/.env
47
+ env_paths.append(module_path.parents[3] / ".env") # project root/.env
48
+ except IndexError:
49
+ # Fallback for shallow/flattened structures
50
+ env_paths.append(module_path.parent / ".env")
51
+ if module_path.parent.name != "ai":
52
+ env_paths.append(module_path.parent.parent / ".env")
53
+
54
+ for env_path in env_paths:
55
+ try:
56
+ if env_path.exists() and env_path.is_file():
57
+ load_dotenv(env_path, override=False)
58
+ break
59
+ except Exception:
60
+ continue
61
+
62
+ logger = logging.getLogger(__name__)
63
+
64
+
65
+ class S3DatasetLoader:
66
+ """
67
+ Load datasets from S3 with streaming support for large files.
68
+ S3 is the canonical training data location.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ bucket: str = "pixel-data",
74
+ endpoint_url: str | None = None,
75
+ aws_access_key_id: str | None = None,
76
+ aws_secret_access_key: str | None = None,
77
+ region_name: str = "us-east-va",
78
+ ):
79
+ """
80
+ Initialize S3 client for dataset loading.
81
+
82
+ Args:
83
+ bucket: S3 bucket name (default: pixel-data)
84
+ endpoint_url: S3 endpoint URL (default: OVH S3 endpoint)
85
+ aws_access_key_id: AWS access key (from env if not provided)
86
+ aws_secret_access_key: AWS secret key (from env if not provided)
87
+ region_name: AWS region (default: us-east-va for OVH)
88
+ """
89
+ if boto3 is None:
90
+ raise ImportError(
91
+ "boto3 is required for S3 dataset loading. "
92
+ "Install with: uv pip install boto3"
93
+ )
94
+
95
+ # Always allow env to override bucket for OVH S3
96
+ # This ensures OVH_S3_BUCKET is always used when set
97
+ self.bucket = os.getenv("OVH_S3_BUCKET", bucket)
98
+ print(
99
+ f"[DEBUG] S3Loader: env OVH_S3_BUCKET={os.getenv('OVH_S3_BUCKET')}, "
100
+ f"input bucket={bucket}, final={self.bucket}",
101
+ flush=True,
102
+ )
103
+ self.endpoint_url = endpoint_url or os.getenv(
104
+ "OVH_S3_ENDPOINT", "https://s3.us-east-va.io.cloud.ovh.us"
105
+ )
106
+
107
+ # Get credentials from params or environment
108
+ access_key = (
109
+ aws_access_key_id
110
+ or os.getenv("OVH_S3_ACCESS_KEY")
111
+ or os.getenv("OVH_ACCESS_KEY")
112
+ or os.getenv("AWS_ACCESS_KEY_ID")
113
+ )
114
+ secret_key = (
115
+ aws_secret_access_key
116
+ or os.getenv("OVH_S3_SECRET_KEY")
117
+ or os.getenv("OVH_SECRET_KEY")
118
+ or os.getenv("AWS_SECRET_ACCESS_KEY")
119
+ )
120
+
121
+ if not access_key or not secret_key:
122
+ raise ValueError(
123
+ "S3 credentials not found. Set OVH_S3_ACCESS_KEY/OVH_S3_SECRET_KEY "
124
+ "(or OVH_ACCESS_KEY/OVH_SECRET_KEY, "
125
+ "or AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)."
126
+ )
127
+
128
+ # Initialize S3 client (OVH S3 compatible)
129
+ # OVH uses self-signed certificates, so verify=False is required
130
+ # Initialize S3 client (OVH S3 compatible)
131
+ # OVH uses self-signed certificates, so verify=False is required
132
+ verify_ssl = os.getenv("OVH_S3_CA_BUNDLE", True)
133
+ # Handle string "False" or "0" from env
134
+ if str(verify_ssl).lower() in {"false", "0", "no"}:
135
+ verify_ssl = False
136
+
137
+ if verify_ssl is False:
138
+ logger.warning(
139
+ "Initializing S3 client with SSL verification DISABLED (insecure)"
140
+ )
141
+
142
+ self.s3_client = boto3.client(
143
+ "s3",
144
+ endpoint_url=self.endpoint_url,
145
+ aws_access_key_id=access_key,
146
+ aws_secret_access_key=secret_key,
147
+ region_name=region_name or os.getenv("OVH_S3_REGION", "us-east-va"),
148
+ verify=verify_ssl,
149
+ )
150
+
151
+ logger.info(f"S3DatasetLoader initialized for bucket: {bucket}")
152
+
153
+ def _parse_s3_path(self, s3_path: str) -> tuple[str, str]:
154
+ """
155
+ Parse S3 path into bucket and key.
156
+
157
+ Args:
158
+ s3_path: S3 path (s3://bucket/key or just key)
159
+
160
+ Returns:
161
+ Tuple of (bucket, key)
162
+ """
163
+ # If path starts with s3://, it includes bucket
164
+ if s3_path.startswith("s3://"):
165
+ s3_path = s3_path.removeprefix("s3://")
166
+ if "/" in s3_path:
167
+ parts = s3_path.split("/", 1)
168
+ return parts[0], parts[1]
169
+ # s3://bucket-only (no key)
170
+ return s3_path, ""
171
+
172
+ # Otherwise, it's just a key - use configured bucket
173
+ return self.bucket, s3_path
174
+
175
+ def object_exists(self, s3_path: str) -> bool:
176
+ """Check if S3 object exists"""
177
+ try:
178
+ bucket, key = self._parse_s3_path(s3_path)
179
+ self.s3_client.head_object(Bucket=bucket, Key=key)
180
+ return True
181
+ except ClientError as e:
182
+ if e.response["Error"]["Code"] == "404":
183
+ return False
184
+ raise
185
+
186
+ def load_json(
187
+ self,
188
+ s3_path: str,
189
+ cache_local: Path | None = None,
190
+ ) -> dict[str, Any]:
191
+ """
192
+ Load JSON dataset from S3.
193
+
194
+ Args:
195
+ s3_path: S3 path (s3://bucket/key or just key)
196
+ cache_local: Optional local cache path
197
+
198
+ Returns:
199
+ Parsed JSON data
200
+ """
201
+ bucket, key = self._parse_s3_path(s3_path)
202
+
203
+ # Check local cache first
204
+ if cache_local and cache_local.exists():
205
+ logger.info(f"Loading from local cache: {cache_local}")
206
+ with open(cache_local) as f:
207
+ return json.load(f)
208
+
209
+ # Load from S3
210
+ logger.info(f"Loading from S3: s3://{bucket}/{key}")
211
+ try:
212
+ response = self.s3_client.get_object(Bucket=bucket, Key=key)
213
+ data = json.loads(response["Body"].read())
214
+
215
+ # Cache locally if requested
216
+ if cache_local:
217
+ cache_local.parent.mkdir(parents=True, exist_ok=True)
218
+ with open(cache_local, "w") as f:
219
+ json.dump(data, f)
220
+ logger.info(f"Cached to: {cache_local}")
221
+
222
+ return data
223
+ except ClientError as e:
224
+ if e.response["Error"]["Code"] == "NoSuchKey":
225
+ raise FileNotFoundError(
226
+ f"Dataset not found in S3: s3://{bucket}/{key}"
227
+ ) from e
228
+ raise
229
+
230
+ def load_bytes(self, s3_path: str) -> bytes:
231
+ """
232
+ Load raw bytes from S3.
233
+
234
+ Args:
235
+ s3_path: S3 path (s3://bucket/key or just key)
236
+
237
+ Returns:
238
+ Raw bytes of the object body
239
+ """
240
+ bucket, key = self._parse_s3_path(s3_path)
241
+ logger.info(f"Loading bytes from S3: s3://{bucket}/{key}")
242
+
243
+ try:
244
+ response = self.s3_client.get_object(Bucket=bucket, Key=key)
245
+ return response["Body"].read()
246
+ except ClientError as e:
247
+ if e.response["Error"]["Code"] == "NoSuchKey":
248
+ raise FileNotFoundError(
249
+ f"Dataset not found in S3: s3://{bucket}/{key}"
250
+ ) from e
251
+ raise
252
+
253
+ def load_text(
254
+ self,
255
+ s3_path: str,
256
+ *,
257
+ encoding: str = "utf-8",
258
+ errors: str = "replace",
259
+ ) -> str:
260
+ """
261
+ Load a text object from S3.
262
+
263
+ This is primarily for transcript corpora (e.g. .txt) that need to be
264
+ converted into ChatML examples.
265
+ """
266
+ data = self.load_bytes(s3_path)
267
+ return data.decode(encoding, errors=errors)
268
+
269
+ def _parse_jsonl_line(self, line: bytes) -> dict[str, Any] | None:
270
+ """
271
+ Parse a single JSONL line with robust error handling.
272
+
273
+ Args:
274
+ line: Raw bytes of a JSONL line
275
+
276
+ Returns:
277
+ Parsed JSON object or None if parsing failed
278
+ """
279
+ try:
280
+ return json.loads(line.decode("utf-8"))
281
+ except UnicodeDecodeError:
282
+ try:
283
+ return json.loads(line.decode("utf-8", errors="replace"))
284
+ except json.JSONDecodeError as e:
285
+ logger.warning(f"Failed to parse JSONL line: {e}")
286
+ except json.JSONDecodeError as e:
287
+ logger.warning(f"Failed to parse JSONL line: {e}")
288
+ return None
289
+
290
+ def _stream_with_iter_lines(self, body) -> Iterator[dict[str, Any]]:
291
+ """
292
+ Stream JSONL using iter_lines() method.
293
+
294
+ Args:
295
+ body: S3 response body with iter_lines capability
296
+
297
+ Yields:
298
+ Parsed JSON objects
299
+ """
300
+ for raw_line in body.iter_lines():
301
+ if not raw_line:
302
+ continue
303
+ parsed = self._parse_jsonl_line(raw_line)
304
+ if parsed is not None:
305
+ yield parsed
306
+
307
+ def _stream_with_manual_buffering(self, body) -> Iterator[dict[str, Any]]:
308
+ """
309
+ Stream JSONL using manual buffering as fallback.
310
+
311
+ Args:
312
+ body: S3 response body
313
+
314
+ Yields:
315
+ Parsed JSON objects
316
+ """
317
+ buffer = BytesIO()
318
+ for chunk in body.iter_chunks(chunk_size=8192):
319
+ buffer.write(chunk)
320
+ while True:
321
+ buffer.seek(0)
322
+ line = buffer.readline()
323
+ if not line:
324
+ buffer = BytesIO()
325
+ break
326
+ if not line.endswith(b"\n"):
327
+ # Keep incomplete tail in buffer
328
+ rest = buffer.read()
329
+ buffer = BytesIO(line + rest)
330
+ break
331
+
332
+ parsed = self._parse_jsonl_line(line)
333
+ if parsed is not None:
334
+ yield parsed
335
+
336
+ rest = buffer.read()
337
+ buffer = BytesIO(rest)
338
+
339
+ def stream_jsonl(self, s3_path: str) -> Iterator[dict[str, Any]]:
340
+ """
341
+ Stream JSONL dataset from S3 (memory-efficient for large files).
342
+
343
+ Args:
344
+ s3_path: S3 path (s3://bucket/key or just key)
345
+
346
+ Yields:
347
+ Parsed JSON objects (one per line)
348
+ """
349
+ bucket, key = self._parse_s3_path(s3_path)
350
+
351
+ logger.info(f"Streaming JSONL from S3: s3://{bucket}/{key}")
352
+ try:
353
+ response = self.s3_client.get_object(Bucket=bucket, Key=key)
354
+ body = response["Body"]
355
+
356
+ with contextlib.closing(body):
357
+ # Prefer iter_lines() which handles chunk boundaries robustly
358
+ iter_lines = getattr(body, "iter_lines", None)
359
+ if callable(iter_lines):
360
+ yield from self._stream_with_iter_lines(body)
361
+ else:
362
+ # Fallback to manual buffering
363
+ yield from self._stream_with_manual_buffering(body)
364
+
365
+ except ClientError as e:
366
+ if e.response["Error"]["Code"] == "NoSuchKey":
367
+ raise FileNotFoundError(
368
+ f"Dataset not found in S3: s3://{bucket}/{key}"
369
+ ) from e
370
+ raise
371
+
372
+ def list_datasets(self, prefix: str = "gdrive/processed/") -> list[str]:
373
+ """
374
+ List available datasets in S3.
375
+
376
+ Args:
377
+ prefix: S3 prefix to search (default: gdrive/processed/)
378
+
379
+ Returns:
380
+ List of S3 paths
381
+ """
382
+ logger.info(f"Listing datasets with prefix: {prefix}")
383
+ datasets: list[str] = []
384
+
385
+ try:
386
+ paginator = self.s3_client.get_paginator("list_objects_v2")
387
+ pages = paginator.paginate(Bucket=self.bucket, Prefix=prefix)
388
+
389
+ for page in pages:
390
+ if "Contents" in page:
391
+ datasets.extend(
392
+ f"s3://{self.bucket}/{obj['Key']}"
393
+ for obj in page["Contents"]
394
+ if obj["Key"].endswith((".json", ".jsonl"))
395
+ )
396
+
397
+ except ClientError:
398
+ logger.exception("Failed to list S3 objects")
399
+ raise
400
+ return datasets
401
+
402
+ def download_file(self, s3_path: str, local_path: Path | str) -> None:
403
+ """Download a file from S3 to local path"""
404
+ try:
405
+ bucket, key = self._parse_s3_path(s3_path)
406
+ logger.info(f"Downloading s3://{bucket}/{key} to {local_path}")
407
+ self.s3_client.download_file(bucket, key, str(local_path))
408
+ except Exception:
409
+ logger.exception(f"Failed to download {s3_path} to {local_path}")
410
+ raise
411
+
412
+ def upload_file(self, local_path: Path | str, s3_key: str) -> None:
413
+ """Upload a local file to S3"""
414
+ try:
415
+ if not isinstance(local_path, Path):
416
+ local_path = Path(local_path)
417
+
418
+ bucket, key = self._parse_s3_path(s3_key)
419
+
420
+ logger.info(f"Uploading {local_path} to s3://{bucket}/{key}")
421
+ self.s3_client.upload_file(str(local_path), bucket, key)
422
+ except Exception:
423
+ logger.exception(f"Failed to upload {local_path} to {s3_key}")
424
+ raise
425
+
426
+
427
def get_s3_dataset_path(
    dataset_name: str,
    category: str | None = None,
    bucket: str = "pixel-data",
    prefer_processed: bool = True,
) -> str:
    """
    Get S3 path for dataset - S3 is canonical training data location.

    Args:
        dataset_name: Name of the dataset file
        category: Optional category (cot_reasoning, professional_therapeutic, etc.)
        bucket: S3 bucket name
        prefer_processed: Prefer processed/canonical structure over raw

    Returns:
        S3 path (s3://bucket/path)
    """
    loader = S3DatasetLoader(bucket=bucket)

    # Probe candidate locations in priority order: canonical processed,
    # then raw, then the legacy "acquired" prefix.
    candidates: list[str] = []
    if prefer_processed:
        if category:
            candidates.append(
                f"s3://{bucket}/gdrive/processed/{category}/{dataset_name}"
            )
        candidates.append(f"s3://{bucket}/gdrive/raw/{dataset_name}")
    candidates.append(f"s3://{bucket}/acquired/{dataset_name}")

    for path in candidates:
        if loader.object_exists(path):
            return path

    # Nothing exists yet: return the canonical location where it would live.
    if category:
        return f"s3://{bucket}/gdrive/processed/{category}/{dataset_name}"
    return f"s3://{bucket}/gdrive/raw/{dataset_name}"
469
+
470
+
471
def load_dataset_from_s3(
    dataset_name: str,
    category: str | None = None,
    cache_local: Path | None = None,
    bucket: str = "pixel-data",
) -> dict[str, Any]:
    """
    Load dataset from S3 with automatic path resolution.

    Args:
        dataset_name: Name of the dataset file
        category: Optional category for canonical structure
        cache_local: Optional local cache path (applies to JSON datasets)
        bucket: S3 bucket name

    Returns:
        Dataset data; JSONL files are wrapped as {"conversations": [...]}.
    """
    loader = S3DatasetLoader(bucket=bucket)
    s3_path = get_s3_dataset_path(dataset_name, category, bucket)

    if not dataset_name.endswith(".jsonl"):
        return loader.load_json(s3_path, cache_local)
    # JSONL streams are materialized into a list for a uniform return shape.
    return {"conversations": list(loader.stream_jsonl(s3_path))}
utils/subtitle_processor.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from pathlib import Path
3
+
4
+
5
class SubtitleProcessor:
    """Utility for cleaning and formatting YouTube VTT subtitles."""

    # 00:00:00.000 --> 00:00:00.000 (plus any trailing cue settings).
    # Compiled once at class level instead of on every clean_vtt call.
    _TIMESTAMP_RE = re.compile(
        r"\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}.*"
    )
    # Inline cue tags such as <00:00:00.000> and <c>...</c>
    _TAG_RE = re.compile(r"<[^>]+>")
    # Metadata lines emitted by YouTube (Kind:, Language:, etc.)
    _METADATA_PREFIXES = ("Kind:", "Language:", "align:", "position:")

    @staticmethod
    def clean_vtt(vtt_content: str) -> str:
        """
        Clean VTT content by removing timestamps, tags, and duplicates.

        YouTube automatic captions often repeat lines with incremental
        words, so exact-duplicate lines are dropped to keep each unique
        segment once.

        Args:
            vtt_content: Raw contents of a .vtt file.

        Returns:
            The transcript as a single space-joined string.
        """
        lines = vtt_content.split("\n")
        # Drop the WEBVTT header line if present.
        if lines and lines[0].startswith("WEBVTT"):
            lines = lines[1:]

        seen: set[str] = set()
        segments: list[str] = []

        for raw in lines:
            line = raw.strip()
            if not line:
                continue
            # Skip metadata lines; tuple form of startswith checks all
            # prefixes in one call.
            if line.startswith(SubtitleProcessor._METADATA_PREFIXES):
                continue
            if SubtitleProcessor._TIMESTAMP_RE.match(line):
                continue

            text = SubtitleProcessor._TAG_RE.sub("", line).strip()
            if not text or text in seen:
                continue
            seen.add(text)
            segments.append(text)

        # NOTE: fully removing partially-overlapping repeats (e.g.
        # "Hello world" / "Hello world there") would need NLP; only exact
        # duplicate lines are removed here.
        return " ".join(segments)

    @staticmethod
    def format_as_markdown(text: str, metadata: dict) -> str:
        """
        Format cleaned transcript text as a structured Markdown document.

        Args:
            text: Cleaned transcript text.
            metadata: Mapping with optional "title", "channel", "url",
                and "date" keys; sensible defaults are used when missing.

        Returns:
            Markdown with a metadata header and the transcript split into
            paragraphs of roughly six sentences.
        """
        title = metadata.get("title", "Unknown Title")
        channel = metadata.get("channel", "Unknown Channel")
        video_url = metadata.get("url", "")
        date = metadata.get("date", "")

        header = (
            f"# {title}\n\n"
            f"**Channel:** {channel}\n"
            f"**Source:** {video_url}\n"
            f"**Date:** {date}\n\n"
            "## Transcript\n\n"
        )

        # Group sentences into paragraphs of six for readability.
        sentences = re.split(r"(?<=[.!?])\s+", text)
        paragraphs = [
            " ".join(sentences[i : i + 6]) for i in range(0, len(sentences), 6)
        ]
        return header + "\n\n".join(paragraphs) + "\n"
+ return md
utils/transcript_corrector.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any, Dict
6
+
7
+ # Configure logger
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class TranscriptCorrector:
    """
    Utility class for correcting transcripts using a multi-pass approach:
    1. Therapeutic Terminology Validation
    2. LLM-based Contextual Correction (Mocked for now)
    3. Structural Alignment (Basic regex cleanup)
    """

    # Filler words removed during structural cleanup, optionally followed
    # by a comma. NOTE(review): this also strips legitimate uses of words
    # like "like" / "I mean" -- confirm that is acceptable for this corpus.
    _FILLER_RE = re.compile(
        r"\b(um|uh|err|ah|like|you know|I mean)\b,?\s*", re.IGNORECASE
    )

    def __init__(self, config_path: str = "ai/config/therapeutic_terminology.json"):
        """
        Initialize the TranscriptCorrector with terminology configuration.

        Args:
            config_path: Path to the JSON configuration file containing
                therapeutic terms.
        """
        self.config_path = Path(config_path)
        # Keys: cptsd_terms, medical_terms, common_misinterpretations.
        self.terms: dict[str, Any] = self._load_terminology()

    @staticmethod
    def _empty_terms() -> dict[str, Any]:
        """Return a fresh, empty terminology config (safe fallback)."""
        return {
            "cptsd_terms": [],
            "medical_terms": [],
            "common_misinterpretations": {},
        }

    def _load_terminology(self) -> dict[str, Any]:
        """
        Load therapeutic terminology from the JSON config.

        Falls back to an empty config (with a logged warning/error) rather
        than raising, so transcript correction degrades gracefully.
        """
        try:
            if not self.config_path.exists():
                # Resolve relative to the package layout: this file lives at
                # ai/utils/..., the config at ai/config/..., so go up 2 levels.
                alt_path = (
                    Path(__file__).parent.parent
                    / "config"
                    / "therapeutic_terminology.json"
                )
                if not alt_path.exists():
                    logger.warning(
                        f"Terminology config not found at {self.config_path} or "
                        f"{alt_path}. Using empty config."
                    )
                    return self._empty_terms()
                self.config_path = alt_path

            with open(self.config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            # logger.exception keeps the traceback; the previous logger.error
            # call discarded it.
            logger.exception("Failed to load terminology config")
            return self._empty_terms()

    def correct_transcript(self, text: str, context: str = "therapy_session") -> str:
        """
        Main entry point for transcript correction.

        Args:
            text: Single string containing the transcript text to correct.
            context: Context hint for LLM correction.

        Returns:
            Corrected transcript text ("" for empty/blank input).
        """
        if not text or not text.strip():
            return ""

        # Pass 1: Basic Structural Cleanup
        text = self._clean_structure(text)
        # Pass 2: Terminology Replacement
        text = self._apply_terminology_fixes(text)
        # Pass 3: LLM Contextual Correction (Mocked)
        return self._llm_contextual_correction(text, context)

    def _clean_structure(self, text: str) -> str:
        """Remove filler words and normalize whitespace."""
        cleaned = self._FILLER_RE.sub("", text)
        # Collapse whitespace runs left behind by filler removal.
        return re.sub(r"\s+", " ", cleaned).strip()

    def _apply_terminology_fixes(self, text: str) -> str:
        """
        Apply deterministic, case-insensitive terminology fixes from config.

        Config entries are escaped and matched as literal substrings (the
        previous comment claimed word-boundary matching, which the code
        never did).
        """
        misinterpretations = self.terms.get("common_misinterpretations", {})
        for bad_term, good_term in misinterpretations.items():
            pattern = re.compile(re.escape(bad_term), re.IGNORECASE)
            text = pattern.sub(good_term, text)
        return text

    def _llm_contextual_correction(self, text: str, context: str) -> str:
        """
        Mock function for LLM-based correction.

        TODO: Implement the actual LLM call (external service or local
        model) to fix grammar and CPTSD-specific nuance; until then the
        text is returned unchanged.
        """
        return text

    def validate_term_coverage(self, text: str) -> dict[str, float]:
        """
        Calculate naive metrics on how well the transcript uses domain
        terminology. Useful for a validation pass.

        Returns:
            Dict with cptsd_term_count / medical_term_count (ints) and
            domain_coverage_score (fraction of known terms present).
        """
        cptsd_terms = {t.lower() for t in self.terms.get("cptsd_terms", [])}
        medical_terms = {t.lower() for t in self.terms.get("medical_terms", [])}

        haystack = text.lower()
        found_cptsd = sum(term in haystack for term in cptsd_terms)
        found_medical = sum(term in haystack for term in medical_terms)

        total_domain_terms = len(cptsd_terms) + len(medical_terms)
        found_total = found_cptsd + found_medical

        coverage_score = (
            found_total / total_domain_terms if total_domain_terms > 0 else 0.0
        )

        return {
            "cptsd_term_count": found_cptsd,
            "medical_term_count": found_medical,
            "domain_coverage_score": round(coverage_score, 4),
        }
+ }