Upload folder using huggingface_hub
Browse files- utils/__init__.py +25 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/llm_capabilities.cpython-311.pyc +0 -0
- utils/__pycache__/ngc_cli.cpython-311.pyc +0 -0
- utils/__pycache__/ngc_resources.cpython-311.pyc +0 -0
- utils/__pycache__/s3_dataset_loader.cpython-311.pyc +0 -0
- utils/__pycache__/transcript_corrector.cpython-311.pyc +0 -0
- utils/llm_capabilities.py +68 -0
- utils/ngc_cli.py +400 -0
- utils/ngc_resources.py +184 -0
- utils/s3_dataset_loader.py +495 -0
- utils/subtitle_processor.py +96 -0
- utils/transcript_corrector.py +155 -0
utils/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared AI utilities for training and dataset pipelines.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .ngc_cli import (
|
| 6 |
+
NGCCLI,
|
| 7 |
+
NGCCLIAuthError,
|
| 8 |
+
NGCCLIDownloadError,
|
| 9 |
+
NGCCLIError,
|
| 10 |
+
NGCCLINotFoundError,
|
| 11 |
+
NGCConfig,
|
| 12 |
+
ensure_ngc_cli_configured,
|
| 13 |
+
get_ngc_cli,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"NGCCLI",
|
| 18 |
+
"NGCCLIAuthError",
|
| 19 |
+
"NGCCLIDownloadError",
|
| 20 |
+
"NGCCLIError",
|
| 21 |
+
"NGCCLINotFoundError",
|
| 22 |
+
"NGCConfig",
|
| 23 |
+
"ensure_ngc_cli_configured",
|
| 24 |
+
"get_ngc_cli",
|
| 25 |
+
]
|
utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (584 Bytes). View file
|
|
|
utils/__pycache__/llm_capabilities.cpython-311.pyc
ADDED
|
Binary file (3.35 kB). View file
|
|
|
utils/__pycache__/ngc_cli.cpython-311.pyc
ADDED
|
Binary file (17.8 kB). View file
|
|
|
utils/__pycache__/ngc_resources.cpython-311.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
utils/__pycache__/s3_dataset_loader.cpython-311.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
utils/__pycache__/transcript_corrector.cpython-311.pyc
ADDED
|
Binary file (7.8 kB). View file
|
|
|
utils/llm_capabilities.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from google import genai
|
| 4 |
+
|
| 5 |
+
_WORKING_MODEL_CACHE = None
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_best_available_gemini_model(client: genai.Client) -> str:
    """
    Probe the Gemini API and return the best model id that actually works
    for this API key's tier/region.

    Hardcoded model names can 404 when a key is restricted, so the list of
    models is fetched live and each preferred candidate is test-invoked
    before being accepted. The winning model id is cached at module level.
    """
    global _WORKING_MODEL_CACHE
    if _WORKING_MODEL_CACHE:
        return _WORKING_MODEL_CACHE

    # Candidates in preference order: stable 2.0 flash first, then the
    # lite / latest / 2.5 tiers.
    preference_order = [
        "models/gemini-2.0-flash-001",
        "models/gemini-2.0-flash-lite-001",
        "models/gemini-flash-latest",
        "models/gemini-pro-latest",
        "models/gemini-2.5-flash",
        "models/gemini-2.5-pro",
    ]

    try:
        available_models = [model.name for model in client.models.list()]
        print(f"DISCOVERED MODELS on this key: {available_models}")
    except Exception as e:
        print(f"Failed to list models: {e}")
        return "gemini-1.5-flash"  # Fallback guess

    for target in preference_order:
        for listed in available_models:
            if target != listed and not listed.endswith(target):
                continue
            # Being listed does not guarantee invokability: some entries
            # still 404 on use due to tier/region constraints.
            try:
                client.models.generate_content(model=target, contents="ping")
            except Exception as eval_e:
                print(f"Model {target} is listed but uninvokeable: {eval_e}")
                continue
            _WORKING_MODEL_CACHE = target
            print(f"Dynamically locked to functioning Gemini model: {target}")
            return target

    print(
        "CRITICAL WARNING: No preferred Gemini models available on this API Key. "
        "Falling back to gemini-flash-latest."
    )
    return "models/gemini-flash-latest"
|
| 54 |
+
|
| 55 |
+
def ensure_valid_key() -> str:
    """Return the configured Gemini REST API key, rejecting OAuth tokens.

    Checks GOOGLE_CLOUD_API_KEY first, then GEMINI_API_KEY.

    Raises:
        ValueError: If neither variable holds a value, or the value looks
            like an OAuth token (``AQ...``) rather than a REST key (``AIza...``).
    """
    api_key = None
    for env_var in ("GOOGLE_CLOUD_API_KEY", "GEMINI_API_KEY"):
        candidate = os.environ.get(env_var)
        if candidate:
            api_key = candidate
            break

    if api_key is None:
        raise ValueError(
            "Neither GOOGLE_CLOUD_API_KEY nor GEMINI_API_KEY are configured."
        )

    if api_key.startswith("AQ"):
        raise ValueError(
            "Provided GEMINI_API_KEY is an OAuth token (AQ...). "
            "The AI engine requires a Google Cloud REST API key (AIza...). "
            "Please update your .env file."
        )

    return api_key
utils/ngc_cli.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NGC CLI Utility Module
|
| 3 |
+
|
| 4 |
+
Provides utilities for working with NVIDIA GPU Cloud (NGC) CLI to download
|
| 5 |
+
NeMo resources, datasets, and other NGC catalog resources.
|
| 6 |
+
|
| 7 |
+
This module handles:
|
| 8 |
+
- NGC CLI detection and installation
|
| 9 |
+
- Resource download from NGC catalog
|
| 10 |
+
- Configuration management
|
| 11 |
+
- Error handling and retry logic
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import subprocess
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class NGCConfig:
    """NGC CLI configuration values used when setting up the CLI."""

    # NGC API key from https://catalog.ngc.nvidia.com; None = unset.
    api_key: str | None = None
    # Optional NGC organization name.
    org: str | None = None
    # Optional NGC team name within the organization.
    team: str | None = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class NGCCLIError(Exception):
    """Base exception for NGC CLI operations; all module errors derive from it."""


class NGCCLINotFoundError(NGCCLIError):
    """Raised when the NGC CLI binary cannot be found or is not installed."""


class NGCCLIAuthError(NGCCLIError):
    """Raised when the NGC CLI is not configured or authentication fails."""


class NGCCLIDownloadError(NGCCLIError):
    """Raised when an NGC resource download fails."""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class NGCCLI:
|
| 51 |
+
"""
|
| 52 |
+
NGC CLI wrapper for downloading resources from NVIDIA GPU Cloud.
|
| 53 |
+
|
| 54 |
+
Supports multiple installation methods:
|
| 55 |
+
1. System-installed ngc in PATH
|
| 56 |
+
2. Local installation at ~/ngc-cli/ngc
|
| 57 |
+
3. Python package via uv (ngc-python-cli)
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, use_uv: bool = True):
|
| 61 |
+
"""
|
| 62 |
+
Initialize NGC CLI wrapper.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
use_uv: If True, prefer uv-based installation if ngc not in PATH
|
| 66 |
+
"""
|
| 67 |
+
self.use_uv = use_uv
|
| 68 |
+
self.ngc_cmd: str | None = None
|
| 69 |
+
self.uv_cmd: str | None = None
|
| 70 |
+
self._detect_ngc_cli()
|
| 71 |
+
|
| 72 |
+
def _detect_ngc_cli(self) -> None:
|
| 73 |
+
"""Detect and set up NGC CLI command"""
|
| 74 |
+
# Method 1: Check if ngc is in PATH
|
| 75 |
+
if shutil.which("ngc"):
|
| 76 |
+
self.ngc_cmd = "ngc"
|
| 77 |
+
logger.info("Found NGC CLI in PATH")
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
# Method 2: Check common installation location
|
| 81 |
+
home_ngc = Path.home() / "ngc-cli" / "ngc"
|
| 82 |
+
if home_ngc.exists():
|
| 83 |
+
self.ngc_cmd = str(home_ngc)
|
| 84 |
+
# Add to PATH for subprocess calls
|
| 85 |
+
env_path = os.environ.get("PATH", "")
|
| 86 |
+
os.environ["PATH"] = f"{home_ngc.parent}:{env_path}"
|
| 87 |
+
logger.info(f"Found NGC CLI at {home_ngc}")
|
| 88 |
+
return
|
| 89 |
+
|
| 90 |
+
# Method 3: Use uv to run ngc (if enabled)
|
| 91 |
+
if self.use_uv:
|
| 92 |
+
self._setup_uv_ngc()
|
| 93 |
+
|
| 94 |
+
def _setup_uv_ngc(self) -> None:
|
| 95 |
+
"""Set up NGC CLI via uv"""
|
| 96 |
+
# Find uv
|
| 97 |
+
if shutil.which("uv"):
|
| 98 |
+
self.uv_cmd = "uv"
|
| 99 |
+
elif (Path.home() / ".local" / "bin" / "uv").exists():
|
| 100 |
+
self.uv_cmd = str(Path.home() / ".local" / "bin" / "uv")
|
| 101 |
+
elif (Path.home() / ".cargo" / "bin" / "uv").exists():
|
| 102 |
+
self.uv_cmd = str(Path.home() / ".cargo" / "bin" / "uv")
|
| 103 |
+
else:
|
| 104 |
+
logger.warning("uv not found, cannot use uv-based NGC CLI")
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
# Check if ngc is installed via uv
|
| 108 |
+
try:
|
| 109 |
+
result = subprocess.run(
|
| 110 |
+
[self.uv_cmd, "pip", "list"],
|
| 111 |
+
capture_output=True,
|
| 112 |
+
text=True,
|
| 113 |
+
check=False,
|
| 114 |
+
)
|
| 115 |
+
if "ngc" in result.stdout.lower():
|
| 116 |
+
self.ngc_cmd = f"{self.uv_cmd} run ngc"
|
| 117 |
+
logger.info("Found NGC CLI via uv")
|
| 118 |
+
return
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.debug(f"Error checking uv packages: {e}")
|
| 121 |
+
|
| 122 |
+
# Note: NGC CLI is not a Python package on PyPI
|
| 123 |
+
# It must be downloaded from https://catalog.ngc.nvidia.com
|
| 124 |
+
# We can only check if it's available in PATH or local installation
|
| 125 |
+
# The uv method here is for running Python-based NGC SDK if available
|
| 126 |
+
logger.debug("NGC CLI must be installed separately from NVIDIA website")
|
| 127 |
+
|
| 128 |
+
def is_available(self) -> bool:
|
| 129 |
+
"""Check if NGC CLI is available"""
|
| 130 |
+
return self.ngc_cmd is not None
|
| 131 |
+
|
| 132 |
+
def ensure_available(self) -> None:
|
| 133 |
+
"""Ensure NGC CLI is available, raise error if not"""
|
| 134 |
+
if not self.is_available():
|
| 135 |
+
raise NGCCLINotFoundError(
|
| 136 |
+
"NGC CLI not found. Please install it:\n"
|
| 137 |
+
" 1. Download from https://catalog.ngc.nvidia.com\n"
|
| 138 |
+
" 2. Or install to ~/ngc-cli/ directory\n"
|
| 139 |
+
" 3. Or add to system PATH\n"
|
| 140 |
+
"\n"
|
| 141 |
+
"Note: NGC CLI is not available as a PyPI package.\n"
|
| 142 |
+
"You must download it directly from NVIDIA."
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def check_config(self) -> dict[str, Any]:
|
| 146 |
+
"""
|
| 147 |
+
Check NGC CLI configuration.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Configuration dictionary with API key status, org, team, etc.
|
| 151 |
+
|
| 152 |
+
Raises:
|
| 153 |
+
NGCCLINotFoundError: If NGC CLI is not available
|
| 154 |
+
NGCCLIAuthError: If authentication is not configured
|
| 155 |
+
"""
|
| 156 |
+
self.ensure_available()
|
| 157 |
+
|
| 158 |
+
if self.ngc_cmd is None:
|
| 159 |
+
raise NGCCLINotFoundError("NGC CLI command not set")
|
| 160 |
+
try:
|
| 161 |
+
result = subprocess.run(
|
| 162 |
+
[*self.ngc_cmd.split(), "config", "current"],
|
| 163 |
+
capture_output=True,
|
| 164 |
+
text=True,
|
| 165 |
+
check=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
config = {}
|
| 169 |
+
# Parse the table format output
|
| 170 |
+
lines = result.stdout.strip().split("\n")
|
| 171 |
+
current_key = None
|
| 172 |
+
|
| 173 |
+
for line in lines:
|
| 174 |
+
if "|" in line and "| key " not in line.lower() and "---" not in line:
|
| 175 |
+
parts = [part.strip() for part in line.split("|") if part.strip()]
|
| 176 |
+
if len(parts) >= 3: # key | value | source
|
| 177 |
+
key, value, source = parts[0], parts[1], parts[2]
|
| 178 |
+
if key: # New key
|
| 179 |
+
current_key = key
|
| 180 |
+
config[key] = value
|
| 181 |
+
elif current_key and value: # Continuation of previous key
|
| 182 |
+
config[current_key] += value
|
| 183 |
+
elif len(parts) == 1 and current_key: # Just a value continuation
|
| 184 |
+
config[current_key] += parts[0]
|
| 185 |
+
|
| 186 |
+
# Check if API key is configured (it will be masked with asterisks)
|
| 187 |
+
# If we have any config and apikey exists (even masked), consider it configured
|
| 188 |
+
if config and ("apikey" in config or "API key" in config):
|
| 189 |
+
return config
|
| 190 |
+
|
| 191 |
+
raise NGCCLIAuthError(
|
| 192 |
+
"NGC CLI not configured. Run: ngc config set\n"
|
| 193 |
+
"Get your API key from: https://catalog.ngc.nvidia.com"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
return config
|
| 197 |
+
except subprocess.CalledProcessError as e:
|
| 198 |
+
raise NGCCLIAuthError(f"Failed to check NGC config: {e.stderr}") from e
|
| 199 |
+
|
| 200 |
+
def set_config(
|
| 201 |
+
self, api_key: str, _org: str | None = None, _team: str | None = None
|
| 202 |
+
) -> None:
|
| 203 |
+
"""
|
| 204 |
+
Configure NGC CLI with API key.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
api_key: NGC API key from https://catalog.ngc.nvidia.com
|
| 208 |
+
_org: Optional organization name (reserved for future use)
|
| 209 |
+
_team: Optional team name (reserved for future use)
|
| 210 |
+
"""
|
| 211 |
+
self.ensure_available()
|
| 212 |
+
|
| 213 |
+
if self.ngc_cmd is None:
|
| 214 |
+
raise NGCCLINotFoundError("NGC CLI command not set")
|
| 215 |
+
# Set API key
|
| 216 |
+
try:
|
| 217 |
+
subprocess.run(
|
| 218 |
+
[*self.ngc_cmd.split(), "config", "set"],
|
| 219 |
+
input=f"{api_key}\n",
|
| 220 |
+
text=True,
|
| 221 |
+
check=True,
|
| 222 |
+
capture_output=True,
|
| 223 |
+
)
|
| 224 |
+
logger.info("NGC CLI configured successfully")
|
| 225 |
+
except subprocess.CalledProcessError as e:
|
| 226 |
+
raise NGCCLIAuthError(f"Failed to configure NGC CLI: {e.stderr}") from e
|
| 227 |
+
|
| 228 |
+
def download_resource(
|
| 229 |
+
self,
|
| 230 |
+
resource_path: str,
|
| 231 |
+
version: str | None = None,
|
| 232 |
+
output_dir: Path | None = None,
|
| 233 |
+
extract: bool = True, # noqa: ARG002
|
| 234 |
+
) -> Path:
|
| 235 |
+
"""
|
| 236 |
+
Download a resource from NGC catalog.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
resource_path: Resource path in format "org/team/resource" or "nvidia/nemo-microservices/nemo-microservices-quickstart"
|
| 240 |
+
version: Optional version tag (e.g., "25.10")
|
| 241 |
+
output_dir: Optional output directory (defaults to current directory)
|
| 242 |
+
extract: Whether to extract downloaded archive
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
Path to downloaded/extracted resource
|
| 246 |
+
|
| 247 |
+
Raises:
|
| 248 |
+
NGCCLINotFoundError: If NGC CLI is not available
|
| 249 |
+
NGCCLIAuthError: If authentication failed
|
| 250 |
+
NGCCLIDownloadError: If download failed
|
| 251 |
+
"""
|
| 252 |
+
self.ensure_available()
|
| 253 |
+
|
| 254 |
+
# Check config first
|
| 255 |
+
try:
|
| 256 |
+
self.check_config()
|
| 257 |
+
except NGCCLIAuthError:
|
| 258 |
+
logger.warning("NGC CLI not configured. Attempting download anyway...")
|
| 259 |
+
|
| 260 |
+
if output_dir is None:
|
| 261 |
+
output_dir = Path.cwd()
|
| 262 |
+
else:
|
| 263 |
+
output_dir = Path(output_dir)
|
| 264 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 265 |
+
|
| 266 |
+
if self.ngc_cmd is None:
|
| 267 |
+
raise NGCCLINotFoundError("NGC CLI command not set")
|
| 268 |
+
# Build download command
|
| 269 |
+
cmd = [*self.ngc_cmd.split(), "registry", "resource", "download-version"]
|
| 270 |
+
|
| 271 |
+
resource_spec = f"{resource_path}:{version}" if version else resource_path
|
| 272 |
+
|
| 273 |
+
cmd.append(resource_spec)
|
| 274 |
+
|
| 275 |
+
# Change to output directory for download
|
| 276 |
+
original_cwd = Path.cwd()
|
| 277 |
+
try:
|
| 278 |
+
return self._execute_download_in_directory(output_dir, resource_spec, cmd)
|
| 279 |
+
finally:
|
| 280 |
+
os.chdir(original_cwd)
|
| 281 |
+
|
| 282 |
+
def _execute_download_in_directory(
|
| 283 |
+
self, output_dir: Path, resource_spec: str, cmd: list[str]
|
| 284 |
+
) -> Path:
|
| 285 |
+
"""
|
| 286 |
+
Execute download command in the specified directory and locate the downloaded resource.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
output_dir: Directory to download into
|
| 290 |
+
resource_spec: Resource specification string for logging
|
| 291 |
+
cmd: Command to execute
|
| 292 |
+
|
| 293 |
+
Returns:
|
| 294 |
+
Path to the downloaded resource (most recently modified item, or output_dir if empty)
|
| 295 |
+
|
| 296 |
+
Raises:
|
| 297 |
+
NGCCLIDownloadError: If download fails
|
| 298 |
+
"""
|
| 299 |
+
os.chdir(output_dir)
|
| 300 |
+
logger.info(f"Downloading {resource_spec} to {output_dir}...")
|
| 301 |
+
|
| 302 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
| 303 |
+
|
| 304 |
+
if result.returncode != 0:
|
| 305 |
+
error_msg = result.stderr or result.stdout
|
| 306 |
+
raise NGCCLIDownloadError(
|
| 307 |
+
f"Failed to download {resource_spec}:\n{error_msg}"
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
logger.info(f"Successfully downloaded {resource_spec}")
|
| 311 |
+
|
| 312 |
+
if downloaded_items := list(output_dir.iterdir()):
|
| 313 |
+
# Return the most recently modified item
|
| 314 |
+
return max(downloaded_items, key=lambda p: p.stat().st_mtime)
|
| 315 |
+
|
| 316 |
+
return output_dir
|
| 317 |
+
|
| 318 |
+
def list_resources(
|
| 319 |
+
self, org: str | None = None, team: str | None = None
|
| 320 |
+
) -> list[dict[str, Any]]:
|
| 321 |
+
"""
|
| 322 |
+
List available resources in NGC catalog.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
org: Optional organization filter
|
| 326 |
+
team: Optional team filter
|
| 327 |
+
|
| 328 |
+
Returns:
|
| 329 |
+
List of resource dictionaries
|
| 330 |
+
"""
|
| 331 |
+
self.ensure_available()
|
| 332 |
+
|
| 333 |
+
if self.ngc_cmd is None:
|
| 334 |
+
raise NGCCLINotFoundError("NGC CLI command not set")
|
| 335 |
+
cmd = [*self.ngc_cmd.split(), "registry", "resource", "list"]
|
| 336 |
+
|
| 337 |
+
if org:
|
| 338 |
+
cmd.extend(["--org", org])
|
| 339 |
+
if team:
|
| 340 |
+
cmd.extend(["--team", team])
|
| 341 |
+
|
| 342 |
+
try:
|
| 343 |
+
subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 344 |
+
|
| 345 |
+
# Parse output (format may vary)
|
| 346 |
+
# TODO: Implement proper parsing based on actual NGC CLI output format
|
| 347 |
+
return []
|
| 348 |
+
except subprocess.CalledProcessError as e:
|
| 349 |
+
logger.warning(f"Failed to list resources: {e.stderr}")
|
| 350 |
+
return []
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def get_ngc_cli(use_uv: bool = True) -> NGCCLI:
    """Construct and return a fresh :class:`NGCCLI` wrapper.

    Detection of the underlying ``ngc`` binary happens in the constructor.

    Args:
        use_uv: Prefer a uv-managed installation when ``ngc`` is not in PATH.

    Returns:
        A newly built NGCCLI instance.
    """
    return NGCCLI(use_uv=use_uv)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def ensure_ngc_cli_configured(api_key: str | None = None) -> NGCCLI:
    """
    Ensure NGC CLI is available and configured.

    Args:
        api_key: Optional API key to configure (if not already configured)

    Returns:
        Configured NGCCLI instance

    Raises:
        NGCCLINotFoundError: If NGC CLI cannot be found or installed
        NGCCLIAuthError: If configuration fails
    """
    cli = get_ngc_cli()

    if not cli.is_available():
        # Fix: the previous message suggested installing "ngc-python-cli" via
        # pip/uv, but NGC CLI is not distributed on PyPI (see NGCCLI notes);
        # point users at the official download instead.
        raise NGCCLINotFoundError(
            "NGC CLI not available.\n"
            "Download it from: https://catalog.ngc.nvidia.com\n"
            "and install to ~/ngc-cli/ or add it to your PATH.\n"
            "Note: NGC CLI is not available as a PyPI package."
        )

    # Check if already configured
    try:
        cli.check_config()
        return cli
    except NGCCLIAuthError as err:
        if api_key:
            cli.set_config(api_key)
            return cli
        raise NGCCLIAuthError(
            "NGC CLI not configured. Provide API key or run: ngc config set\n"
            "Get API key from: https://catalog.ngc.nvidia.com"
        ) from err
|
utils/ngc_resources.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NGC Resources Downloader for Training Ready
|
| 3 |
+
|
| 4 |
+
Downloads NeMo resources and training-related assets from NGC catalog.
|
| 5 |
+
Integrates with training_ready pipeline for automated resource acquisition.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add parent directory to path for imports
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 15 |
+
|
| 16 |
+
from ai.utils.ngc_cli import (
|
| 17 |
+
NGCCLIAuthError,
|
| 18 |
+
NGCCLINotFoundError,
|
| 19 |
+
ensure_ngc_cli_configured,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NGCResourceDownloader:
|
| 26 |
+
"""
|
| 27 |
+
Downloads NeMo and training resources from NGC catalog.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
# Common NeMo resources used in training
|
| 31 |
+
NEMO_RESOURCES = {
|
| 32 |
+
"nemo-microservices-quickstart": {
|
| 33 |
+
"path": "nvidia/nemo-microservices/nemo-microservices-quickstart",
|
| 34 |
+
"default_version": "25.10",
|
| 35 |
+
"description": "NeMo Microservices quickstart package",
|
| 36 |
+
},
|
| 37 |
+
"nemo-framework": {
|
| 38 |
+
"path": "nvidia/nemo/nemo",
|
| 39 |
+
"default_version": None, # Use latest
|
| 40 |
+
"description": "NeMo framework for training",
|
| 41 |
+
},
|
| 42 |
+
"nemo-megatron": {
|
| 43 |
+
"path": "nvidia/nemo/nemo-megatron",
|
| 44 |
+
"default_version": None,
|
| 45 |
+
"description": "NeMo Megatron for large-scale training",
|
| 46 |
+
},
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
def __init__(self, output_base: Path | None = None, api_key: str | None = None):
|
| 50 |
+
"""
|
| 51 |
+
Initialize NGC resource downloader.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
output_base: Base directory for downloads (defaults to training_ready/resources/)
|
| 55 |
+
api_key: Optional NGC API key (if not set, will check environment or prompt)
|
| 56 |
+
"""
|
| 57 |
+
if output_base is None:
|
| 58 |
+
output_base = Path(__file__).parent.parent / "resources"
|
| 59 |
+
self.output_base = Path(output_base)
|
| 60 |
+
self.output_base.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
# Get API key from environment if not provided
|
| 63 |
+
if api_key is None:
|
| 64 |
+
api_key = os.environ.get("NGC_API_KEY")
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
self.cli = ensure_ngc_cli_configured(api_key=api_key)
|
| 68 |
+
except (NGCCLINotFoundError, NGCCLIAuthError) as e:
|
| 69 |
+
logger.warning(f"NGC CLI not available: {e}")
|
| 70 |
+
|
| 71 |
+
self.cli = None
|
| 72 |
+
|
| 73 |
+
def download_nemo_quickstart(
|
| 74 |
+
self, version: str | None = None, output_dir: Path | None = None
|
| 75 |
+
) -> Path:
|
| 76 |
+
"""
|
| 77 |
+
Download NeMo Microservices quickstart package.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
version: Version to download (defaults to 25.10)
|
| 81 |
+
output_dir: Output directory (defaults to resources/nemo-microservices/)
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
Path to downloaded/extracted quickstart directory
|
| 85 |
+
"""
|
| 86 |
+
if not self.cli:
|
| 87 |
+
raise NGCCLINotFoundError("NGC CLI not available")
|
| 88 |
+
|
| 89 |
+
if version is None:
|
| 90 |
+
version = self.NEMO_RESOURCES["nemo-microservices-quickstart"][
|
| 91 |
+
"default_version"
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
if output_dir is None:
|
| 95 |
+
output_dir = self.output_base / "nemo-microservices"
|
| 96 |
+
|
| 97 |
+
resource_path = self.NEMO_RESOURCES["nemo-microservices-quickstart"]["path"]
|
| 98 |
+
|
| 99 |
+
logger.info(f"Downloading NeMo Microservices quickstart v{version}...")
|
| 100 |
+
return self.cli.download_resource(
|
| 101 |
+
resource_path=resource_path,
|
| 102 |
+
version=version,
|
| 103 |
+
output_dir=output_dir,
|
| 104 |
+
extract=True,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def download_nemo_framework(
|
| 108 |
+
self, version: str | None = None, output_dir: Path | None = None
|
| 109 |
+
) -> Path:
|
| 110 |
+
"""
|
| 111 |
+
Download NeMo framework.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
version: Version to download
|
| 115 |
+
output_dir: Output directory
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Path to downloaded framework
|
| 119 |
+
"""
|
| 120 |
+
if not self.cli:
|
| 121 |
+
raise NGCCLINotFoundError("NGC CLI not available")
|
| 122 |
+
|
| 123 |
+
if output_dir is None:
|
| 124 |
+
output_dir = self.output_base / "nemo-framework"
|
| 125 |
+
|
| 126 |
+
resource_path = self.NEMO_RESOURCES["nemo-framework"]["path"]
|
| 127 |
+
|
| 128 |
+
logger.info("Downloading NeMo framework...")
|
| 129 |
+
return self.cli.download_resource(
|
| 130 |
+
resource_path=resource_path,
|
| 131 |
+
version=version,
|
| 132 |
+
output_dir=output_dir,
|
| 133 |
+
extract=True,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
def download_custom_resource(
|
| 137 |
+
self,
|
| 138 |
+
resource_path: str,
|
| 139 |
+
version: str | None = None,
|
| 140 |
+
output_dir: Path | None = None,
|
| 141 |
+
) -> Path:
|
| 142 |
+
"""
|
| 143 |
+
Download a custom resource from NGC catalog.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
resource_path: Resource path (e.g., "nvidia/nemo-microservices/nemo-microservices-quickstart")
|
| 147 |
+
version: Optional version tag
|
| 148 |
+
output_dir: Optional output directory
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Path to downloaded resource
|
| 152 |
+
"""
|
| 153 |
+
if not self.cli:
|
| 154 |
+
raise NGCCLINotFoundError("NGC CLI not available")
|
| 155 |
+
|
| 156 |
+
if output_dir is None:
|
| 157 |
+
# Create directory from resource name
|
| 158 |
+
resource_name = resource_path.split("/")[-1]
|
| 159 |
+
output_dir = self.output_base / resource_name
|
| 160 |
+
|
| 161 |
+
logger.info(f"Downloading {resource_path}...")
|
| 162 |
+
return self.cli.download_resource(
|
| 163 |
+
resource_path=resource_path,
|
| 164 |
+
version=version,
|
| 165 |
+
output_dir=output_dir,
|
| 166 |
+
extract=True,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def download_nemo_quickstart(
    version: str | None = None, output_dir: Path | None = None
) -> Path:
    """
    Convenience wrapper: download the NeMo Microservices quickstart using a
    throwaway :class:`NGCResourceDownloader` with default settings.

    Args:
        version: Version to download (defaults to 25.10)
        output_dir: Output directory

    Returns:
        Path to downloaded quickstart directory
    """
    return NGCResourceDownloader().download_nemo_quickstart(
        version=version, output_dir=output_dir
    )
|
utils/s3_dataset_loader.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
S3 Dataset Loader - Streaming JSON/JSONL loader for S3 training data
S3 is the training mecca - all training data should be loaded from S3
"""

import contextlib
import json
import logging
import os
from collections.abc import Iterator
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any

try:
    import boto3
    from botocore.exceptions import ClientError as _BotocoreClientError
except ImportError:
    # Keep runtime behavior (error on use) while making type checkers happy.
    boto3 = None  # type: ignore[assignment]
    _BotocoreClientError = None  # type: ignore[assignment]

if TYPE_CHECKING:
    # Minimal shape we rely on in this module.
    class ClientError(Exception):
        response: dict[str, Any]
else:
    ClientError = (
        _BotocoreClientError if _BotocoreClientError is not None else Exception
    )  # type: ignore[assignment]

BOTO3_AVAILABLE = boto3 is not None

# Load a .env file when python-dotenv is installed.
# Module lives at ai/training_ready/utils/s3_dataset_loader.py, so
# parents[2] is ai/ (where .env actually sits) and parents[3] is the
# project root; try them in that order.
with contextlib.suppress(ImportError):
    from dotenv import load_dotenv

    _module_file = Path(__file__).resolve()
    _candidates: list[Path] = []
    try:
        _candidates.append(_module_file.parents[2] / ".env")  # ai/.env
        _candidates.append(_module_file.parents[3] / ".env")  # project root/.env
    except IndexError:
        # Fallback for shallow/flattened checkouts.
        _candidates.append(_module_file.parent / ".env")
        if _module_file.parent.name != "ai":
            _candidates.append(_module_file.parent.parent / ".env")

    for _candidate in _candidates:
        # Best-effort: any filesystem error just moves on to the next path.
        with contextlib.suppress(Exception):
            if _candidate.exists() and _candidate.is_file():
                load_dotenv(_candidate, override=False)
                break

logger = logging.getLogger(__name__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class S3DatasetLoader:
|
| 66 |
+
"""
|
| 67 |
+
Load datasets from S3 with streaming support for large files.
|
| 68 |
+
S3 is the canonical training data location.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
bucket: str = "pixel-data",
|
| 74 |
+
endpoint_url: str | None = None,
|
| 75 |
+
aws_access_key_id: str | None = None,
|
| 76 |
+
aws_secret_access_key: str | None = None,
|
| 77 |
+
region_name: str = "us-east-va",
|
| 78 |
+
):
|
| 79 |
+
"""
|
| 80 |
+
Initialize S3 client for dataset loading.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
bucket: S3 bucket name (default: pixel-data)
|
| 84 |
+
endpoint_url: S3 endpoint URL (default: OVH S3 endpoint)
|
| 85 |
+
aws_access_key_id: AWS access key (from env if not provided)
|
| 86 |
+
aws_secret_access_key: AWS secret key (from env if not provided)
|
| 87 |
+
region_name: AWS region (default: us-east-va for OVH)
|
| 88 |
+
"""
|
| 89 |
+
if boto3 is None:
|
| 90 |
+
raise ImportError(
|
| 91 |
+
"boto3 is required for S3 dataset loading. "
|
| 92 |
+
"Install with: uv pip install boto3"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Always allow env to override bucket for OVH S3
|
| 96 |
+
# This ensures OVH_S3_BUCKET is always used when set
|
| 97 |
+
self.bucket = os.getenv("OVH_S3_BUCKET", bucket)
|
| 98 |
+
print(
|
| 99 |
+
f"[DEBUG] S3Loader: env OVH_S3_BUCKET={os.getenv('OVH_S3_BUCKET')}, "
|
| 100 |
+
f"input bucket={bucket}, final={self.bucket}",
|
| 101 |
+
flush=True,
|
| 102 |
+
)
|
| 103 |
+
self.endpoint_url = endpoint_url or os.getenv(
|
| 104 |
+
"OVH_S3_ENDPOINT", "https://s3.us-east-va.io.cloud.ovh.us"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Get credentials from params or environment
|
| 108 |
+
access_key = (
|
| 109 |
+
aws_access_key_id
|
| 110 |
+
or os.getenv("OVH_S3_ACCESS_KEY")
|
| 111 |
+
or os.getenv("OVH_ACCESS_KEY")
|
| 112 |
+
or os.getenv("AWS_ACCESS_KEY_ID")
|
| 113 |
+
)
|
| 114 |
+
secret_key = (
|
| 115 |
+
aws_secret_access_key
|
| 116 |
+
or os.getenv("OVH_S3_SECRET_KEY")
|
| 117 |
+
or os.getenv("OVH_SECRET_KEY")
|
| 118 |
+
or os.getenv("AWS_SECRET_ACCESS_KEY")
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
if not access_key or not secret_key:
|
| 122 |
+
raise ValueError(
|
| 123 |
+
"S3 credentials not found. Set OVH_S3_ACCESS_KEY/OVH_S3_SECRET_KEY "
|
| 124 |
+
"(or OVH_ACCESS_KEY/OVH_SECRET_KEY, "
|
| 125 |
+
"or AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Initialize S3 client (OVH S3 compatible)
|
| 129 |
+
# OVH uses self-signed certificates, so verify=False is required
|
| 130 |
+
# Initialize S3 client (OVH S3 compatible)
|
| 131 |
+
# OVH uses self-signed certificates, so verify=False is required
|
| 132 |
+
verify_ssl = os.getenv("OVH_S3_CA_BUNDLE", True)
|
| 133 |
+
# Handle string "False" or "0" from env
|
| 134 |
+
if str(verify_ssl).lower() in {"false", "0", "no"}:
|
| 135 |
+
verify_ssl = False
|
| 136 |
+
|
| 137 |
+
if verify_ssl is False:
|
| 138 |
+
logger.warning(
|
| 139 |
+
"Initializing S3 client with SSL verification DISABLED (insecure)"
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
self.s3_client = boto3.client(
|
| 143 |
+
"s3",
|
| 144 |
+
endpoint_url=self.endpoint_url,
|
| 145 |
+
aws_access_key_id=access_key,
|
| 146 |
+
aws_secret_access_key=secret_key,
|
| 147 |
+
region_name=region_name or os.getenv("OVH_S3_REGION", "us-east-va"),
|
| 148 |
+
verify=verify_ssl,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
logger.info(f"S3DatasetLoader initialized for bucket: {bucket}")
|
| 152 |
+
|
| 153 |
+
def _parse_s3_path(self, s3_path: str) -> tuple[str, str]:
|
| 154 |
+
"""
|
| 155 |
+
Parse S3 path into bucket and key.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
s3_path: S3 path (s3://bucket/key or just key)
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
Tuple of (bucket, key)
|
| 162 |
+
"""
|
| 163 |
+
# If path starts with s3://, it includes bucket
|
| 164 |
+
if s3_path.startswith("s3://"):
|
| 165 |
+
s3_path = s3_path.removeprefix("s3://")
|
| 166 |
+
if "/" in s3_path:
|
| 167 |
+
parts = s3_path.split("/", 1)
|
| 168 |
+
return parts[0], parts[1]
|
| 169 |
+
# s3://bucket-only (no key)
|
| 170 |
+
return s3_path, ""
|
| 171 |
+
|
| 172 |
+
# Otherwise, it's just a key - use configured bucket
|
| 173 |
+
return self.bucket, s3_path
|
| 174 |
+
|
| 175 |
+
def object_exists(self, s3_path: str) -> bool:
|
| 176 |
+
"""Check if S3 object exists"""
|
| 177 |
+
try:
|
| 178 |
+
bucket, key = self._parse_s3_path(s3_path)
|
| 179 |
+
self.s3_client.head_object(Bucket=bucket, Key=key)
|
| 180 |
+
return True
|
| 181 |
+
except ClientError as e:
|
| 182 |
+
if e.response["Error"]["Code"] == "404":
|
| 183 |
+
return False
|
| 184 |
+
raise
|
| 185 |
+
|
| 186 |
+
def load_json(
|
| 187 |
+
self,
|
| 188 |
+
s3_path: str,
|
| 189 |
+
cache_local: Path | None = None,
|
| 190 |
+
) -> dict[str, Any]:
|
| 191 |
+
"""
|
| 192 |
+
Load JSON dataset from S3.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
s3_path: S3 path (s3://bucket/key or just key)
|
| 196 |
+
cache_local: Optional local cache path
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Parsed JSON data
|
| 200 |
+
"""
|
| 201 |
+
bucket, key = self._parse_s3_path(s3_path)
|
| 202 |
+
|
| 203 |
+
# Check local cache first
|
| 204 |
+
if cache_local and cache_local.exists():
|
| 205 |
+
logger.info(f"Loading from local cache: {cache_local}")
|
| 206 |
+
with open(cache_local) as f:
|
| 207 |
+
return json.load(f)
|
| 208 |
+
|
| 209 |
+
# Load from S3
|
| 210 |
+
logger.info(f"Loading from S3: s3://{bucket}/{key}")
|
| 211 |
+
try:
|
| 212 |
+
response = self.s3_client.get_object(Bucket=bucket, Key=key)
|
| 213 |
+
data = json.loads(response["Body"].read())
|
| 214 |
+
|
| 215 |
+
# Cache locally if requested
|
| 216 |
+
if cache_local:
|
| 217 |
+
cache_local.parent.mkdir(parents=True, exist_ok=True)
|
| 218 |
+
with open(cache_local, "w") as f:
|
| 219 |
+
json.dump(data, f)
|
| 220 |
+
logger.info(f"Cached to: {cache_local}")
|
| 221 |
+
|
| 222 |
+
return data
|
| 223 |
+
except ClientError as e:
|
| 224 |
+
if e.response["Error"]["Code"] == "NoSuchKey":
|
| 225 |
+
raise FileNotFoundError(
|
| 226 |
+
f"Dataset not found in S3: s3://{bucket}/{key}"
|
| 227 |
+
) from e
|
| 228 |
+
raise
|
| 229 |
+
|
| 230 |
+
def load_bytes(self, s3_path: str) -> bytes:
|
| 231 |
+
"""
|
| 232 |
+
Load raw bytes from S3.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
s3_path: S3 path (s3://bucket/key or just key)
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Raw bytes of the object body
|
| 239 |
+
"""
|
| 240 |
+
bucket, key = self._parse_s3_path(s3_path)
|
| 241 |
+
logger.info(f"Loading bytes from S3: s3://{bucket}/{key}")
|
| 242 |
+
|
| 243 |
+
try:
|
| 244 |
+
response = self.s3_client.get_object(Bucket=bucket, Key=key)
|
| 245 |
+
return response["Body"].read()
|
| 246 |
+
except ClientError as e:
|
| 247 |
+
if e.response["Error"]["Code"] == "NoSuchKey":
|
| 248 |
+
raise FileNotFoundError(
|
| 249 |
+
f"Dataset not found in S3: s3://{bucket}/{key}"
|
| 250 |
+
) from e
|
| 251 |
+
raise
|
| 252 |
+
|
| 253 |
+
def load_text(
|
| 254 |
+
self,
|
| 255 |
+
s3_path: str,
|
| 256 |
+
*,
|
| 257 |
+
encoding: str = "utf-8",
|
| 258 |
+
errors: str = "replace",
|
| 259 |
+
) -> str:
|
| 260 |
+
"""
|
| 261 |
+
Load a text object from S3.
|
| 262 |
+
|
| 263 |
+
This is primarily for transcript corpora (e.g. .txt) that need to be
|
| 264 |
+
converted into ChatML examples.
|
| 265 |
+
"""
|
| 266 |
+
data = self.load_bytes(s3_path)
|
| 267 |
+
return data.decode(encoding, errors=errors)
|
| 268 |
+
|
| 269 |
+
def _parse_jsonl_line(self, line: bytes) -> dict[str, Any] | None:
|
| 270 |
+
"""
|
| 271 |
+
Parse a single JSONL line with robust error handling.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
line: Raw bytes of a JSONL line
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
Parsed JSON object or None if parsing failed
|
| 278 |
+
"""
|
| 279 |
+
try:
|
| 280 |
+
return json.loads(line.decode("utf-8"))
|
| 281 |
+
except UnicodeDecodeError:
|
| 282 |
+
try:
|
| 283 |
+
return json.loads(line.decode("utf-8", errors="replace"))
|
| 284 |
+
except json.JSONDecodeError as e:
|
| 285 |
+
logger.warning(f"Failed to parse JSONL line: {e}")
|
| 286 |
+
except json.JSONDecodeError as e:
|
| 287 |
+
logger.warning(f"Failed to parse JSONL line: {e}")
|
| 288 |
+
return None
|
| 289 |
+
|
| 290 |
+
def _stream_with_iter_lines(self, body) -> Iterator[dict[str, Any]]:
|
| 291 |
+
"""
|
| 292 |
+
Stream JSONL using iter_lines() method.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
body: S3 response body with iter_lines capability
|
| 296 |
+
|
| 297 |
+
Yields:
|
| 298 |
+
Parsed JSON objects
|
| 299 |
+
"""
|
| 300 |
+
for raw_line in body.iter_lines():
|
| 301 |
+
if not raw_line:
|
| 302 |
+
continue
|
| 303 |
+
parsed = self._parse_jsonl_line(raw_line)
|
| 304 |
+
if parsed is not None:
|
| 305 |
+
yield parsed
|
| 306 |
+
|
| 307 |
+
def _stream_with_manual_buffering(self, body) -> Iterator[dict[str, Any]]:
|
| 308 |
+
"""
|
| 309 |
+
Stream JSONL using manual buffering as fallback.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
body: S3 response body
|
| 313 |
+
|
| 314 |
+
Yields:
|
| 315 |
+
Parsed JSON objects
|
| 316 |
+
"""
|
| 317 |
+
buffer = BytesIO()
|
| 318 |
+
for chunk in body.iter_chunks(chunk_size=8192):
|
| 319 |
+
buffer.write(chunk)
|
| 320 |
+
while True:
|
| 321 |
+
buffer.seek(0)
|
| 322 |
+
line = buffer.readline()
|
| 323 |
+
if not line:
|
| 324 |
+
buffer = BytesIO()
|
| 325 |
+
break
|
| 326 |
+
if not line.endswith(b"\n"):
|
| 327 |
+
# Keep incomplete tail in buffer
|
| 328 |
+
rest = buffer.read()
|
| 329 |
+
buffer = BytesIO(line + rest)
|
| 330 |
+
break
|
| 331 |
+
|
| 332 |
+
parsed = self._parse_jsonl_line(line)
|
| 333 |
+
if parsed is not None:
|
| 334 |
+
yield parsed
|
| 335 |
+
|
| 336 |
+
rest = buffer.read()
|
| 337 |
+
buffer = BytesIO(rest)
|
| 338 |
+
|
| 339 |
+
def stream_jsonl(self, s3_path: str) -> Iterator[dict[str, Any]]:
|
| 340 |
+
"""
|
| 341 |
+
Stream JSONL dataset from S3 (memory-efficient for large files).
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
s3_path: S3 path (s3://bucket/key or just key)
|
| 345 |
+
|
| 346 |
+
Yields:
|
| 347 |
+
Parsed JSON objects (one per line)
|
| 348 |
+
"""
|
| 349 |
+
bucket, key = self._parse_s3_path(s3_path)
|
| 350 |
+
|
| 351 |
+
logger.info(f"Streaming JSONL from S3: s3://{bucket}/{key}")
|
| 352 |
+
try:
|
| 353 |
+
response = self.s3_client.get_object(Bucket=bucket, Key=key)
|
| 354 |
+
body = response["Body"]
|
| 355 |
+
|
| 356 |
+
with contextlib.closing(body):
|
| 357 |
+
# Prefer iter_lines() which handles chunk boundaries robustly
|
| 358 |
+
iter_lines = getattr(body, "iter_lines", None)
|
| 359 |
+
if callable(iter_lines):
|
| 360 |
+
yield from self._stream_with_iter_lines(body)
|
| 361 |
+
else:
|
| 362 |
+
# Fallback to manual buffering
|
| 363 |
+
yield from self._stream_with_manual_buffering(body)
|
| 364 |
+
|
| 365 |
+
except ClientError as e:
|
| 366 |
+
if e.response["Error"]["Code"] == "NoSuchKey":
|
| 367 |
+
raise FileNotFoundError(
|
| 368 |
+
f"Dataset not found in S3: s3://{bucket}/{key}"
|
| 369 |
+
) from e
|
| 370 |
+
raise
|
| 371 |
+
|
| 372 |
+
def list_datasets(self, prefix: str = "gdrive/processed/") -> list[str]:
|
| 373 |
+
"""
|
| 374 |
+
List available datasets in S3.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
prefix: S3 prefix to search (default: gdrive/processed/)
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
List of S3 paths
|
| 381 |
+
"""
|
| 382 |
+
logger.info(f"Listing datasets with prefix: {prefix}")
|
| 383 |
+
datasets: list[str] = []
|
| 384 |
+
|
| 385 |
+
try:
|
| 386 |
+
paginator = self.s3_client.get_paginator("list_objects_v2")
|
| 387 |
+
pages = paginator.paginate(Bucket=self.bucket, Prefix=prefix)
|
| 388 |
+
|
| 389 |
+
for page in pages:
|
| 390 |
+
if "Contents" in page:
|
| 391 |
+
datasets.extend(
|
| 392 |
+
f"s3://{self.bucket}/{obj['Key']}"
|
| 393 |
+
for obj in page["Contents"]
|
| 394 |
+
if obj["Key"].endswith((".json", ".jsonl"))
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
except ClientError:
|
| 398 |
+
logger.exception("Failed to list S3 objects")
|
| 399 |
+
raise
|
| 400 |
+
return datasets
|
| 401 |
+
|
| 402 |
+
def download_file(self, s3_path: str, local_path: Path | str) -> None:
|
| 403 |
+
"""Download a file from S3 to local path"""
|
| 404 |
+
try:
|
| 405 |
+
bucket, key = self._parse_s3_path(s3_path)
|
| 406 |
+
logger.info(f"Downloading s3://{bucket}/{key} to {local_path}")
|
| 407 |
+
self.s3_client.download_file(bucket, key, str(local_path))
|
| 408 |
+
except Exception:
|
| 409 |
+
logger.exception(f"Failed to download {s3_path} to {local_path}")
|
| 410 |
+
raise
|
| 411 |
+
|
| 412 |
+
def upload_file(self, local_path: Path | str, s3_key: str) -> None:
|
| 413 |
+
"""Upload a local file to S3"""
|
| 414 |
+
try:
|
| 415 |
+
if not isinstance(local_path, Path):
|
| 416 |
+
local_path = Path(local_path)
|
| 417 |
+
|
| 418 |
+
bucket, key = self._parse_s3_path(s3_key)
|
| 419 |
+
|
| 420 |
+
logger.info(f"Uploading {local_path} to s3://{bucket}/{key}")
|
| 421 |
+
self.s3_client.upload_file(str(local_path), bucket, key)
|
| 422 |
+
except Exception:
|
| 423 |
+
logger.exception(f"Failed to upload {local_path} to {s3_key}")
|
| 424 |
+
raise
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def get_s3_dataset_path(
    dataset_name: str,
    category: str | None = None,
    bucket: str = "pixel-data",
    prefer_processed: bool = True,
) -> str:
    """
    Get S3 path for dataset - S3 is canonical training data location.

    Args:
        dataset_name: Name of the dataset file
        category: Optional category (cot_reasoning, professional_therapeutic, etc.)
        bucket: S3 bucket name
        prefer_processed: Prefer processed/canonical structure over raw

    Returns:
        S3 path (s3://bucket/path)
    """
    loader = S3DatasetLoader(bucket=bucket)

    # Probe locations in priority order: canonical processed layout, then
    # the raw layout, then the acquired drop zone.
    candidates: list[str] = []
    if prefer_processed:
        if category:
            candidates.append(
                f"s3://{bucket}/gdrive/processed/{category}/{dataset_name}"
            )
        candidates.append(f"s3://{bucket}/gdrive/raw/{dataset_name}")
    candidates.append(f"s3://{bucket}/acquired/{dataset_name}")

    for candidate in candidates:
        if loader.object_exists(candidate):
            return candidate

    # Nothing exists yet: return the canonical target path so callers can
    # still write to it.
    if category:
        return f"s3://{bucket}/gdrive/processed/{category}/{dataset_name}"
    return f"s3://{bucket}/gdrive/raw/{dataset_name}"
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def load_dataset_from_s3(
    dataset_name: str,
    category: str | None = None,
    cache_local: Path | None = None,
    bucket: str = "pixel-data",
) -> dict[str, Any]:
    """
    Load dataset from S3 with automatic path resolution.

    Args:
        dataset_name: Name of the dataset file
        category: Optional category for canonical structure
        cache_local: Optional local cache path
        bucket: S3 bucket name

    Returns:
        Dataset data
    """
    loader = S3DatasetLoader(bucket=bucket)
    s3_path = get_s3_dataset_path(dataset_name, category, bucket)

    if not dataset_name.endswith(".jsonl"):
        return loader.load_json(s3_path, cache_local)
    # JSONL is streamed line-by-line, then materialized into one payload.
    return {"conversations": list(loader.stream_jsonl(s3_path))}
|
utils/subtitle_processor.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SubtitleProcessor:
    """Utility for cleaning and formatting YouTube VTT subtitles."""

    @staticmethod
    def clean_vtt(vtt_content: str) -> str:
        """
        Clean VTT content by removing timestamps, tags, and duplicates.

        YouTube automatic captions often repeat lines with incremental
        words, so any exact line already emitted is dropped.

        Args:
            vtt_content: Raw WEBVTT file contents.

        Returns:
            A single space-joined transcript string.
        """
        # Remove header
        lines = vtt_content.split("\n")
        if lines and lines[0].startswith("WEBVTT"):
            lines = lines[1:]

        # Remove metadata lines (Kind:, Language:, cue settings).
        # str.startswith accepts a tuple, so no any()-loop is needed.
        meta_prefixes = ("Kind:", "Language:", "align:", "position:")
        lines = [line for line in lines if not line.startswith(meta_prefixes)]

        # Cue timing lines: 00:00:00.000 --> 00:00:00.000 (plus settings)
        timestamp_pattern = re.compile(
            r"\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}.*"
        )
        # Inline tags such as <00:00:00.000> and <c>...</c>
        tag_pattern = re.compile(r"<[^>]+>")

        seen_lines: set[str] = set()
        segments: list[str] = []

        for line in lines:
            line = line.strip()
            if not line or timestamp_pattern.match(line):
                continue

            # Strip inline tags
            cleaned_line = tag_pattern.sub("", line).strip()
            if not cleaned_line:
                continue

            # YouTube auto-subs repeat text heavily; keep unique segments.
            if cleaned_line in seen_lines:
                continue

            seen_lines.add(cleaned_line)
            segments.append(cleaned_line)

        # NOTE: further NLP-grade dedup of overlapping phrases (e.g.
        # "Hello world Hello world there") is intentionally out of scope.
        return " ".join(segments)

    @staticmethod
    def format_as_markdown(text: str, metadata: dict) -> str:
        """
        Format the cleaned text as a structured Markdown file.

        Args:
            text: Cleaned transcript text.
            metadata: Dict with optional "title", "channel", "url", "date".

        Returns:
            Markdown document with a header block and paragraphized
            transcript (roughly six sentences per paragraph).
        """
        title = metadata.get("title", "Unknown Title")
        channel = metadata.get("channel", "Unknown Channel")
        video_url = metadata.get("url", "")
        date = metadata.get("date", "")

        md = f"# {title}\n\n"
        md += f"**Channel:** {channel}\n"
        md += f"**Source:** {video_url}\n"
        md += f"**Date:** {date}\n\n"
        md += "## Transcript\n\n"

        # Split into paragraphs of roughly 6 sentences each.
        sentences = re.split(r"(?<=[.!?])\s+", text)
        paragraphs = [
            " ".join(sentences[i : i + 6]) for i in range(0, len(sentences), 6)
        ]

        md += "\n\n".join(paragraphs)
        md += "\n"

        return md
|
utils/transcript_corrector.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
# Configure logger
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TranscriptCorrector:
    """
    Utility class for correcting transcripts using a multi-pass approach:
    1. Therapeutic Terminology Validation
    2. LLM-based Contextual Correction (Mocked for now)
    3. Structural Alignment (Basic regex cleanup)
    """

    def __init__(self, config_path: str = "ai/config/therapeutic_terminology.json"):
        """
        Initialize the TranscriptCorrector with terminology configuration.

        Args:
            config_path: Path to the JSON configuration file containing
                therapeutic terms.
        """
        self.config_path = Path(config_path)
        self.terms: dict[str, Any] = self._load_terminology()

    @staticmethod
    def _empty_terms() -> dict[str, Any]:
        """Fallback terminology config used when no file can be loaded."""
        return {
            "cptsd_terms": [],
            "medical_terms": [],
            "common_misinterpretations": {},
        }

    def _load_terminology(self) -> dict[str, Any]:
        """
        Load therapeutic terminology from JSON config.

        Falls back to a path relative to this file, then to an empty config,
        so a missing or broken file never crashes the pipeline.
        """
        try:
            if not self.config_path.exists():
                # Module usually lives at ai/utils/transcript_corrector.py and
                # the config at ai/config/therapeutic_terminology.json, so
                # look two levels up from this file.
                base_path = Path(__file__).parent.parent
                alt_path = base_path / "config" / "therapeutic_terminology.json"

                if alt_path.exists():
                    self.config_path = alt_path
                else:
                    logger.warning(
                        f"Terminology config not found at {self.config_path} or "
                        f"{alt_path}. Using empty config."
                    )
                    return self._empty_terms()

            with open(self.config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            # logger.exception captures the traceback, unlike logger.error.
            logger.exception("Failed to load terminology config")
            return self._empty_terms()

    def correct_transcript(self, text: str, context: str = "therapy_session") -> str:
        """
        Main entry point for transcript correction.

        Args:
            text: Single string containing the transcript text to correct.
            context: Context hint for LLM correction.

        Returns:
            Corrected transcript text ("" for blank input).
        """
        if not text or not text.strip():
            return ""

        # Pass 1: Basic Structural Cleanup
        text = self._clean_structure(text)

        # Pass 2: Terminology Replacement
        text = self._apply_terminology_fixes(text)

        # Pass 3: LLM Contextual Correction (Mocked)
        return self._llm_contextual_correction(text, context)

    def _clean_structure(self, text: str) -> str:
        """Remove filler words and normalize whitespace."""
        # Common filler words in speech, optionally followed by a comma
        fillers = r"\b(um|uh|err|ah|like|you know|I mean)\b,?\s*"

        # Remove fillers (case-insensitive)
        cleaned = re.sub(fillers, "", text, flags=re.IGNORECASE)

        # Normalize whitespace (collapse runs into single spaces)
        return re.sub(r"\s+", " ", cleaned).strip()

    def _apply_terminology_fixes(self, text: str) -> str:
        """
        Apply deterministic terminology fixes from config.

        Matches are case-insensitive and anchored at word boundaries, so a
        misheard term embedded inside a longer word is left untouched.
        """
        misinterpretations = self.terms.get("common_misinterpretations", {})

        for bad_term, good_term in misinterpretations.items():
            # Fix: the previous pattern had no word boundaries despite the
            # comment claiming it did. Lookarounds act as boundaries even
            # when the term starts/ends with punctuation (where \b fails).
            pattern = re.compile(
                rf"(?<!\w){re.escape(bad_term)}(?!\w)", re.IGNORECASE
            )
            text = pattern.sub(good_term, text)

        return text

    def _llm_contextual_correction(self, text: str, context: str) -> str:
        """
        Mock function for GPT-4 based correction.

        In the future, this will call the LLM service to fix grammar and
        nuances (e.g. "Correct this transcript keeping CPTSD context in
        mind"). For now it returns the text unchanged.
        """
        # TODO: Implement actual LLM call via external service or local model
        return text

    def validate_term_coverage(self, text: str) -> dict[str, Any]:
        """
        Calculate metrics on how well the transcript effectively uses domain
        terminology. Useful for validation pass.

        Returns:
            Dict with integer term counts and a rounded float
            "domain_coverage_score" in [0.0, 1.0].
        """
        cptsd_terms = {t.lower() for t in self.terms.get("cptsd_terms", [])}
        medical_terms = {t.lower() for t in self.terms.get("medical_terms", [])}

        text_lower = text.lower()

        found_cptsd = sum(term in text_lower for term in cptsd_terms)
        found_medical = sum(term in text_lower for term in medical_terms)

        total_domain_terms = len(cptsd_terms) + len(medical_terms)
        found_total = found_cptsd + found_medical

        # Naive substring-presence metric, just for basic validation.
        coverage_score = (
            found_total / total_domain_terms if total_domain_terms > 0 else 0.0
        )

        return {
            "cptsd_term_count": found_cptsd,
            "medical_term_count": found_medical,
            "domain_coverage_score": round(coverage_score, 4),
        }
|