#!/usr/bin/env python3
"""
Codette Shared-Model Batch Adapter Training
--------------------------------------------
Loads the base model ONCE and trains multiple LoRA adapters
sequentially without reloading the 8B model.

Major benefits
--------------
* Eliminates repeated model loads
* Prevents GPU memory fragmentation
* Speeds multi-adapter training
* More stable on 8GB GPUs
"""

import argparse
import json
import logging
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import torch
import yaml

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# ---------------------------------------------------------
# Logging
# ---------------------------------------------------------
def setup_logging():
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"shared_training_{ts}.log"

    logger = logging.getLogger("codette.shared_train")
    logger.setLevel(logging.DEBUG)
    logger.handlers.clear()

    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)

    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)

    fmt = logging.Formatter(
        "%(asctime)s | %(levelname)-8s | %(message)s", "%H:%M:%S"
    )
    fh.setFormatter(fmt)
    ch.setFormatter(fmt)

    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


# ---------------------------------------------------------
# Device Detection
# ---------------------------------------------------------
def detect_device():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# ---------------------------------------------------------
# Config Loaders
# ---------------------------------------------------------
def load_adapter_registry(path):
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    return cfg["adapters"]


def load_training_defaults(path=None):
    if path is None:
        path = Path("configs/default_training.yaml")
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


# ---------------------------------------------------------
# Dataset Loader
# ---------------------------------------------------------
def load_jsonl_dataset(path):
    from datasets import Dataset

    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the JSONL file
            obj = json.loads(line)
            if "messages" in obj:
                rows.append(obj)
    return Dataset.from_list(rows)


def format_chat_messages(example, tokenizer):
    text = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}


# ---------------------------------------------------------
# Base Model Loader
# ---------------------------------------------------------
def load_base_model(model_name, device, logger):
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # --- XPU: streaming file I/O loading (no mmap, avoids OOM) ---
    if device == "xpu":
        logger.info("Intel Arc — streaming CPU load (no mmap, minimal peak memory)")
        import ctypes
        import gc
        import struct as _struct

        from accelerate import init_empty_weights
        from accelerate.utils import set_module_tensor_to_device
        from huggingface_hub import snapshot_download
        from transformers import AutoConfig

        checkpoint_dir = snapshot_download(model_name)
        logger.info(f"Checkpoint: {checkpoint_dir}")
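        # Background for the streaming loader below: a .safetensors file is
        #   [8-byte little-endian u64 header size N][N-byte JSON header][raw data]
        # and the JSON header maps each tensor name to its dtype, shape, and
        # (begin, end) byte offsets inside the data section. That layout lets
        # us build a weight-less (meta-device) skeleton of the model and copy
        # tensors in one at a time with plain file reads, instead of mapping
        # entire multi-GB shards into the page cache at once.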
{checkpoint_dir}") gc.collect() model_config = AutoConfig.from_pretrained( model_name, trust_remote_code=True ) with init_empty_weights(): model = AutoModelForCausalLM.from_config( model_config, trust_remote_code=True ) _dt = { "BF16": torch.bfloat16, "F16": torch.float16, "F32": torch.float32, "F64": torch.float64, "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool, } shard_files = sorted(Path(checkpoint_dir).glob("*.safetensors")) logger.info(f"Loading {len(shard_files)} shards via streaming I/O") for i, shard_file in enumerate(shard_files): logger.info(f" Shard {i+1}/{len(shard_files)}: {shard_file.name}") with open(shard_file, "rb") as fp: header_size = _struct.unpack("