#!/usr/bin/env python3
"""
Codette LoRA Adapter Training Script

Hardware-adaptive version supporting:
    - CUDA (NVIDIA)
    - XPU (Intel Arc)
    - MPS (Apple)
    - CPU fallback
"""

import argparse
import json
import logging
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import yaml
from datasets import Dataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Ensure Intel SYCL runtime DLLs are discoverable for XPU support.
_intel_bin = os.path.join(sys.prefix, "Lib", "site-packages", "Library", "bin")
if os.path.isdir(_intel_bin) and _intel_bin not in os.environ.get("PATH", ""):
    os.environ["PATH"] = _intel_bin + os.pathsep + os.environ.get("PATH", "")

import torch

# ------------------------------------------------------------
# LOGGING
# ------------------------------------------------------------


def setup_logging(output_dir: str, adapter_name: str):
    """Create a per-run logger: DEBUG to a timestamped file, INFO to stdout."""
    log_dir = Path(output_dir) / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"train_{adapter_name}_{timestamp}.log"

    logger = logging.getLogger(f"codette.train.{adapter_name}")
    logger.setLevel(logging.DEBUG)
    logger.handlers.clear()

    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s | %(levelname)-8s | %(message)s", "%H:%M:%S"
    )
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


# ------------------------------------------------------------
# DEVICE DETECTION
# ------------------------------------------------------------


def detect_vulkan_available():
    """Check if Vulkan compute is available (for non-PyTorch acceleration)."""
    try:
        inference_dir = str(Path(__file__).parent.parent / "inference")
        if inference_dir not in sys.path:
            sys.path.insert(0, inference_dir)
        from vulkan_compute import is_vulkan_available

        return is_vulkan_available()
    except Exception:
        return False


def detect_device():
    """Pick the best available backend, in order of preference."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    if detect_vulkan_available():
        return "vulkan"
    return "cpu"


# ------------------------------------------------------------
# CONFIG
# ------------------------------------------------------------


def load_training_config(path=None):
    """Load the YAML training config, defaulting to the bundled file."""
    if path is None:
        path = Path(__file__).parent / "configs" / "default_training.yaml"
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# ------------------------------------------------------------
# DATASET
# ------------------------------------------------------------


def load_jsonl_dataset(dataset_path):
    """Read a JSONL file, keeping only records that carry a "messages" key."""
    records = []
    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines (e.g. a trailing newline)
                continue
            obj = json.loads(line)
            if "messages" not in obj:
                continue
            records.append(obj)
    return Dataset.from_list(records)


def format_chat_messages(example, tokenizer):
    """Render one chat-format record to a single training string."""
    text = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}
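
# Usage sketch: a hypothetical helper (not called anywhere in this script)
# showing how the two dataset utilities above compose. Each JSONL line is
# expected to look like:
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}]}
def _example_build_text_dataset(dataset_path, tokenizer):
    """Illustrative only: load chat records and render each to a "text" field."""
    dataset = load_jsonl_dataset(dataset_path)
    return dataset.map(lambda ex: format_chat_messages(ex, tokenizer))
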
# ------------------------------------------------------------
# MODEL LOADING
# ------------------------------------------------------------


def create_model_and_tokenizer(model_name, device, logger):
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
    )

    logger.info(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {
        "trust_remote_code": True,
        "use_cache": False,
    }

    # ---------------- Intel XPU: streaming file I/O loading ----------------
    # Arc 140V: 8 GB VRAM (too small for a 16 GB bf16 model), and
    # BitsAndBytes is CUDA-only. from_pretrained, load_checkpoint_and_dispatch,
    # and safe_open all mmap the checkpoint, which OOMs here.
    # Fix: parse the safetensors binary format with plain open()+read(), no mmap.
    if device == "xpu":
        logger.info("Intel Arc: streaming CPU load (no mmap, minimal peak memory)")
        import ctypes
        import gc
        import struct as _struct

        from accelerate import init_empty_weights
        from accelerate.utils import set_module_tensor_to_device
        from huggingface_hub import snapshot_download
        from transformers import AutoConfig

        checkpoint_dir = snapshot_download(model_name)
        logger.info(f"Checkpoint: {checkpoint_dir}")
        gc.collect()

        model_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=True
        )
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                model_config, trust_remote_code=True
            )

        # safetensors dtype strings -> torch dtypes
        _dt = {
            "BF16": torch.bfloat16,
            "F16": torch.float16,
            "F32": torch.float32,
            "F64": torch.float64,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }

        shard_files = sorted(Path(checkpoint_dir).glob("*.safetensors"))
        logger.info(f"Loading {len(shard_files)} shards via streaming I/O")
        for i, shard_file in enumerate(shard_files):
            logger.info(f"  Shard {i + 1}/{len(shard_files)}: {shard_file.name}")
            with open(shard_file, "rb") as fp:
                # safetensors layout: an 8-byte little-endian header length,
                # then a JSON header, then the raw tensor data.
                header_size = _struct.unpack("<Q", fp.read(8))[0]
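                # ---- Hedged continuation sketch ----------------------------
                # The source is truncated at the unpack call above. The rest of
                # this loop body is reconstructed from the documented
                # safetensors layout (JSON header mapping tensor name ->
                # dtype/shape/data_offsets, with offsets relative to the start
                # of the data section): a sketch of the intended streaming
                # load, not the verbatim original.
                header = json.loads(fp.read(header_size))
                data_start = 8 + header_size
                # Visit tensors in file order so the shard is scanned forward.
                entries = sorted(
                    (
                        (name, meta)
                        for name, meta in header.items()
                        if name != "__metadata__"
                    ),
                    key=lambda kv: kv[1]["data_offsets"][0],
                )
                for name, meta in entries:
                    begin, end = meta["data_offsets"]
                    fp.seek(data_start + begin)
                    raw = bytearray(fp.read(end - begin))  # plain read, no mmap
                    tensor = torch.frombuffer(
                        raw, dtype=_dt[meta["dtype"]]
                    ).reshape(meta["shape"])
                    # Materialize this weight on CPU inside the meta-device model.
                    set_module_tensor_to_device(model, name, "cpu", value=tensor)
            gc.collect()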