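"""
Distributed PDF-to-Markdown transformation pipeline (one SLURM task per node).

Each task launches a local Ollama server, builds a 16k-context variant of
qwen2.5:7b, claims a strided slice of the input PDFs, and processes that slice
with MarkerFolderProcessor from extract.py.
"""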
import sys
import os
import subprocess
import time
import logging
import requests
import torch
from pathlib import Path
from datetime import timedelta
# --- 1. LOGGING SETUP ---
# Identify the node rank for logging clarity.
NODE_ID = os.environ.get("SLURM_PROCID", "0")

# FileHandler raises FileNotFoundError if logs/ does not exist yet.
Path("logs").mkdir(exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format=f'%(asctime)s - [Node {NODE_ID}] - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(f"logs/node_{NODE_ID}_transform.log")
    ]
)
logger = logging.getLogger(__name__)


def main():
    t_start = time.perf_counter()
    logger.info(f"🚀 Starting Transformation Pipeline on Node {NODE_ID}")

    # --- 2. ENVIRONMENT & PATHS ---
    SCRATCH = Path(os.environ.get("SCRATCH", "/tmp"))
    INPUT_PDFS_DIR = SCRATCH / "mshauri-fedha/data/knbs/pdfs"
    OUTPUT_DIR = SCRATCH / "mshauri-fedha/data/knbs/marker-output"
    OLLAMA_HOME = SCRATCH / "ollama_core"
    OLLAMA_BIN = OLLAMA_HOME / "bin/ollama"
    OLLAMA_HOST = "http://localhost:11434"

    # Important: ensure the current directory is on sys.path for the 'extract' import.
    if os.getcwd() not in sys.path:
        sys.path.append(os.getcwd())
    try:
        from extract import MarkerFolderProcessor, configure_parallelism
    except ImportError:
        logger.error(f"Could not import extract.py from {os.getcwd()}")
        raise

    # --- 3. DYNAMIC PARALLELISM & OLLAMA CONFIG ---
    # Calculate worker counts from the node hardware (GH200, 96 GB).
    total_slots, workers_per_gpu, num_gpus = configure_parallelism()

    # Clean up any zombie servers left over on this node.
    subprocess.run(["pkill", "-f", "ollama serve"], stderr=subprocess.DEVNULL)
    time.sleep(5)

    # Set server environment variables.
    server_env = os.environ.copy()
    server_env["OLLAMA_NUM_PARALLEL"] = str(total_slots)  # concurrent requests per loaded model
    server_env["OLLAMA_MAX_LOADED_MODELS"] = "1"          # keep a single model resident on the GPU
    server_env["OLLAMA_MAX_QUEUE"] = "2048"               # requests allowed to wait before being rejected
    server_env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # reduce CUDA fragmentation

    logger.info(f"⏳ Launching Ollama Server with {total_slots} slots...")
    subprocess.Popen(
        [str(OLLAMA_BIN), "serve"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        env=server_env
    )

    # Heartbeat check: poll for up to ~60 s until the server answers.
    for _ in range(60):
        try:
            if requests.get(OLLAMA_HOST, timeout=2).status_code == 200:
                logger.info("Ollama Server is online.")
                break
        except requests.RequestException:
            pass
        time.sleep(1)
    else:
        raise RuntimeError("Ollama server heartbeat failed.")

    # --- 4. MODEL SETUP ---
    BASE_MODEL = "qwen2.5:7b"
    CUSTOM_MODEL_NAME = "qwen2.5-7b-16k"

    logger.info(f"Pulling {BASE_MODEL}...")
    subprocess.run([str(OLLAMA_BIN), "pull", BASE_MODEL], check=True, capture_output=True)

    logger.info("Creating custom 16k-context model...")
    # Resolve to an absolute path so cleanup still finds the file after os.chdir(SCRATCH) below.
    modelfile_path = Path(f"Modelfile_node_{NODE_ID}").resolve()
    modelfile_path.write_text(f"FROM {BASE_MODEL}\nPARAMETER num_ctx 16384")
    subprocess.run(
        [str(OLLAMA_BIN), "create", CUSTOM_MODEL_NAME, "-f", str(modelfile_path)],
        check=True, capture_output=True
    )

    # --- 5. AUTOMATED DATA PARTITIONING ---
    # Gather all PDFs and sort them so every node sees the same deterministic ordering.
    all_pdfs = sorted(INPUT_PDFS_DIR.glob("*.pdf"))
    total_nodes = int(os.environ.get("SLURM_NTASKS", 1))
    node_rank = int(NODE_ID)

    # Strided split: each node takes every Nth file (node 0 takes indices 0, N, 2N, ...;
    # node 1 takes 1, N+1, 2N+1, ...), spreading the work without any coordination.
    my_pdfs = all_pdfs[node_rank::total_nodes]
    my_pdf_strs = [str(p) for p in my_pdfs]
    logger.info(f"Data Partitioning: Node {node_rank}/{total_nodes} handling {len(my_pdfs)} files.")

    # --- 6. EXECUTION ---
    os.chdir(SCRATCH)
    processor = MarkerFolderProcessor(
        output_dir=OUTPUT_DIR,
        ollama_url=OLLAMA_HOST,
        ollama_model=CUSTOM_MODEL_NAME,
        batch_multiplier=4,
        workers_per_gpu=workers_per_gpu,
        num_gpus=num_gpus
    )
    logger.info("🚀 Processing PDFs...")
    # Uses the 'subset' parameter of process_folder (ensure extract.py supports this).
    processor.process_folder(INPUT_PDFS_DIR, batch_size=5, subset=my_pdf_strs)

    # --- 7. CLEANUP & TIMING ---
    t_end = time.perf_counter()
    duration = timedelta(seconds=t_end - t_start)
    logger.info("Transformation process finished.")
    logger.info(f"⏱️ Total Duration for Node {NODE_ID}: {duration}")

    # Shut down the server and remove the temporary Modelfile.
    subprocess.run(["pkill", "-f", "ollama serve"], stderr=subprocess.DEVNULL)
    modelfile_path.unlink(missing_ok=True)


if __name__ == "__main__":
    main()
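
# Example launch (hypothetical script name and flag values -- adjust to your cluster).
# srun sets SLURM_PROCID and SLURM_NTASKS per task, which drive the partitioning above:
#
#   srun --ntasks=4 --ntasks-per-node=1 --gpus-per-node=1 python transform.py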