# Multi-node PDF transformation pipeline: launches a per-node Ollama server
# and runs the Marker folder processor over a SLURM-partitioned PDF set.
import sys
import os
import subprocess
import time
import logging
import requests
import torch
from pathlib import Path
from datetime import timedelta
# --- 1. LOGGING SETUP ---
# Identify Node Rank for logging clarity (SLURM_PROCID is this task's rank).
NODE_ID = os.environ.get("SLURM_PROCID", "0")

# logging.FileHandler raises FileNotFoundError if the target directory is
# missing, so make sure ./logs exists before configuring handlers.
Path("logs").mkdir(exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format=f'%(asctime)s - [Node {NODE_ID}] - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(f"logs/node_{NODE_ID}_transform.log"),
    ],
)
logger = logging.getLogger(__name__)
def _wait_for_ollama(host, attempts=60):
    """Poll *host* until it answers HTTP 200, or raise after *attempts* tries.

    Sleeps one second between every failed attempt — including non-200
    responses (the original only slept on connection errors, busy-looping
    otherwise) — and bounds each probe with a timeout so one hung request
    cannot stall the whole loop.
    """
    for _ in range(attempts):
        try:
            if requests.get(host, timeout=5).status_code == 200:
                logger.info(" Ollama Server is online.")
                return
        except requests.RequestException:
            pass  # server not up yet; retry below
        time.sleep(1)
    raise RuntimeError(" Ollama server heartbeat failed.")


def main():
    """Run this node's share of the PDF→markdown transformation pipeline.

    Launches a local Ollama server, builds a 16k-context model, partitions
    the input PDFs across SLURM tasks round-robin, and hands this node's
    slice to MarkerFolderProcessor. Server shutdown and Modelfile removal
    happen in a ``finally`` block so a mid-run failure does not leak a
    zombie server or a stale Modelfile.
    """
    t_start = time.perf_counter()
    logger.info(f"🚀 Starting Transformation Pipeline on Node {NODE_ID}")

    # --- 2. ENVIRONMENT & PATHS ---
    SCRATCH = Path(os.environ.get("SCRATCH", "/tmp"))
    INPUT_PDFS_DIR = SCRATCH / "mshauri-fedha/data/knbs/pdfs"
    OUTPUT_DIR = SCRATCH / "mshauri-fedha/data/knbs/marker-output"
    OLLAMA_HOME = SCRATCH / "ollama_core"
    OLLAMA_BIN = OLLAMA_HOME / "bin/ollama"
    OLLAMA_HOST = "http://localhost:11434"

    # Important: Ensure the current directory is in sys.path for 'extract' import
    if os.getcwd() not in sys.path:
        sys.path.append(os.getcwd())
    try:
        from extract import MarkerFolderProcessor, configure_parallelism
    except ImportError:
        logger.error(f"Could not import extract.py from {os.getcwd()}")
        raise  # bare raise preserves the original traceback

    # --- 3. DYNAMIC PARALLELISM & OLLAMA CONFIG ---
    # Calculates workers based on node hardware (GH200 96GB)
    total_slots, workers_per_gpu, num_gpus = configure_parallelism()

    # Clean up any zombie servers on this node
    subprocess.run(["pkill", "-f", "ollama serve"], stderr=subprocess.DEVNULL)
    time.sleep(5)

    # Set server environment variables
    server_env = os.environ.copy()
    server_env["OLLAMA_NUM_PARALLEL"] = str(total_slots)
    server_env["OLLAMA_MAX_LOADED_MODELS"] = "1"
    server_env["OLLAMA_MAX_QUEUE"] = "2048"
    server_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    logger.info(f"⏳ Launching Ollama Server with {total_slots} slots...")
    subprocess.Popen(
        [str(OLLAMA_BIN), "serve"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        env=server_env,
    )
    _wait_for_ollama(OLLAMA_HOST)

    modelfile_path = Path(f"Modelfile_node_{NODE_ID}")
    try:
        # --- 4. MODEL SETUP ---
        BASE_MODEL = "qwen2.5:7b"
        CUSTOM_MODEL_NAME = "qwen2.5-7b-16k"
        logger.info(f" Pulling {BASE_MODEL}...")
        subprocess.run([str(OLLAMA_BIN), "pull", BASE_MODEL],
                       check=True, capture_output=True)
        logger.info(f" Creating custom 16k context model...")
        modelfile_path.write_text(f"FROM {BASE_MODEL}\nPARAMETER num_ctx 16384")
        subprocess.run([str(OLLAMA_BIN), "create", CUSTOM_MODEL_NAME,
                        "-f", str(modelfile_path)],
                       check=True, capture_output=True)

        # --- 5. AUTOMATED DATA PARTITIONING ---
        # Get all PDFs and sort them for deterministic behavior
        all_pdfs = sorted(INPUT_PDFS_DIR.glob("*.pdf"))
        total_nodes = int(os.environ.get("SLURM_NTASKS", 1))
        node_rank = int(NODE_ID)
        # Each node takes every Nth file (Node 0 takes index 0, 2, 4...
        # Node 1 takes 1, 3, 5...)
        my_pdfs = all_pdfs[node_rank::total_nodes]
        my_pdf_strs = [str(p) for p in my_pdfs]
        logger.info(f" Data Partitioning: Node {node_rank}/{total_nodes} handling {len(my_pdfs)} files.")

        # --- 6. EXECUTION ---
        os.chdir(SCRATCH)
        processor = MarkerFolderProcessor(
            output_dir=OUTPUT_DIR,
            ollama_url=OLLAMA_HOST,
            ollama_model=CUSTOM_MODEL_NAME,
            batch_multiplier=4,
            workers_per_gpu=workers_per_gpu,
            num_gpus=num_gpus,
        )
        logger.info(f"🚀 Processing PDFs...")
        # Using the 'subset' parameter in process_folder (ensure extract.py supports this)
        processor.process_folder(INPUT_PDFS_DIR, batch_size=5, subset=my_pdf_strs)
    finally:
        # --- 7. CLEANUP & TIMING ---
        # Runs even on failure so the Ollama server and Modelfile never leak.
        t_end = time.perf_counter()
        duration = timedelta(seconds=t_end - t_start)
        logger.info(" Transformation process finished.")
        logger.info(f"⏱️ Total Duration for Node {NODE_ID}: {duration}")
        # Shutdown server
        subprocess.run(["pkill", "-f", "ollama serve"], stderr=subprocess.DEVNULL)
        if modelfile_path.exists():
            modelfile_path.unlink()
if __name__ == "__main__":
main() |