"""
NullAI Fine-tuning Module

Implements apprentice model fine-tuning using master outputs (Alpaca format).
"""

import os
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime, timezone
import asyncio

logger = logging.getLogger(__name__)


class FineTuningManager:
    """
    Manages fine-tuning of apprentice models using master outputs.

    Supports multiple backends: HuggingFace (PEFT/LoRA), Unsloth, and MLX.
    """

    def __init__(self, training_data_dir: str = "training_data/master_outputs"):
        self.training_data_dir = Path(training_data_dir)
        self.checkpoints_dir = Path("training_data/checkpoints")
        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

        self.current_training_state = {
            "is_training": False,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": 0,
            "loss": 0.0,
            "model_id": None,
            "start_time": None,
        }

    def load_training_data(self, domain_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Load training data from Alpaca-format JSONL files.

        Args:
            domain_id: Specific domain to load. If None, loads all domains.

        Returns:
            List of training examples in Alpaca format.
        """
        training_examples = []

        if not self.training_data_dir.exists():
            logger.warning(f"Training data directory not found: {self.training_data_dir}")
            return training_examples

        if domain_id:
            jsonl_files = [self.training_data_dir / f"master_outputs_{domain_id}.jsonl"]
        else:
            jsonl_files = list(self.training_data_dir.glob("master_outputs_*.jsonl"))

        for jsonl_file in jsonl_files:
            if not jsonl_file.exists():
                logger.warning(f"Training data file not found: {jsonl_file}")
                continue

            logger.info(f"Loading training data from: {jsonl_file}")
            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    # Skip blank lines so a trailing newline does not log a
                    # spurious JSONDecodeError.
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        example = json.loads(line)
                        training_examples.append(example)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse JSON line in {jsonl_file}: {e}")
                        continue

        logger.info(f"Loaded {len(training_examples)} training examples")
        return training_examples
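
    # One line of a master-output JSONL file is a plain Alpaca record. The
    # example below is illustrative (the field values are made up), but the
    # three keys match what load_training_data and the formatter below consume:
    # {"instruction": "Explain recursion.", "input": "", "output": "Recursion is ..."}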

    def format_training_examples_for_model(
        self,
        training_examples: List[Dict[str, Any]],
        template: str = "alpaca"
    ) -> List[str]:
        """
        Format training examples into model-ready prompts.

        Args:
            training_examples: Raw Alpaca-format examples.
            template: Prompt template format ("alpaca", "chatml", "llama3").

        Returns:
            List of formatted prompt strings.
        """
        formatted_prompts = []

        for example in training_examples:
            instruction = example.get("instruction", "")
            input_text = example.get("input", "")
            output_text = example.get("output", "")

            if template == "alpaca":
                if input_text:
                    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
{output_text}"""
                else:
                    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{output_text}"""

            elif template == "chatml":
                prompt = f"""<|im_start|>system
{instruction}<|im_end|>
<|im_start|>user
{input_text}<|im_end|>
<|im_start|>assistant
{output_text}<|im_end|>"""

            elif template == "llama3":
                # Canonical Llama 3 chat template: a blank line follows each
                # <|end_header_id|> before the message content.
                prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>user<|end_header_id|>

{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{output_text}<|eot_id|>"""

            else:
                raise ValueError(f"Unknown template format: {template}")

            formatted_prompts.append(prompt)

        return formatted_prompts
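
    # For reference, template="alpaca" renders {"instruction": "Name a prime.",
    # "input": "", "output": "7"} as (illustrative):
    #
    #   Below is an instruction that describes a task. Write a response that appropriately completes the request.
    #
    #   ### Instruction:
    #   Name a prime.
    #
    #   ### Response:
    #   7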

    async def fine_tune_with_huggingface_peft(
        self,
        model_name: str,
        training_examples: List[Dict[str, Any]],
        output_dir: str,
        epochs: int = 3,
        learning_rate: float = 2e-4,
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        max_seq_length: int = 512,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Fine-tune a model using HuggingFace Transformers + PEFT (LoRA).

        This is the recommended method for most models. It attempts QLoRA
        (4-bit quantization) for memory efficiency and falls back to plain
        float16 loading if quantization is unavailable.
        """
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                TrainingArguments,
                Trainer,
                TrainerCallback,
                DataCollatorForLanguageModeling
            )
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
            from datasets import Dataset
        except ImportError as e:
            logger.error(f"Required libraries not installed: {e}")
            logger.error("Please install: pip install transformers peft datasets bitsandbytes accelerate")
            raise

        logger.info(f"Starting PEFT fine-tuning for model: {model_name}")
        self.current_training_state.update({
            "is_training": True,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": epochs,
            "model_id": model_name,
            "start_time": datetime.now(timezone.utc).isoformat()
        })

        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("Loading model with 4-bit quantization...")
        loaded_in_4bit = True
        try:
            from transformers import BitsAndBytesConfig

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True
            )

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True
            )
        except Exception as e:
            logger.warning(f"4-bit quantization failed, falling back to float16: {e}")
            loaded_in_4bit = False
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

        # prepare_model_for_kbit_training targets quantized models; skip it on
        # the float16 fallback path.
        if loaded_in_4bit:
            model = prepare_model_for_kbit_training(model)

        logger.info("Configuring LoRA...")
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        logger.info("Formatting training data...")
        formatted_texts = self.format_training_examples_for_model(training_examples, template="alpaca")

        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_seq_length,
                padding="max_length"
            )

        dataset = Dataset.from_dict({"text": formatted_texts})
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names
        )

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            save_steps=100,
            save_total_limit=3,
            warmup_steps=50,
            optim="paged_adamw_8bit",
            report_to="none"
        )

        # mlm=False yields standard causal-LM labels (inputs shifted by one).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        class ProgressCallback(TrainerCallback):
            """Push per-epoch progress into the manager's shared training state."""

            def __init__(self, manager, total_epochs, callback):
                self.manager = manager
                self.total_epochs = total_epochs
                self.callback = callback

            def on_epoch_end(self, args, state, control, **kwargs):
                epoch = state.epoch
                loss = state.log_history[-1].get("loss", 0.0) if state.log_history else 0.0

                self.manager.current_training_state.update({
                    "current_epoch": int(epoch),
                    "progress": (epoch / self.total_epochs) * 100,
                    "loss": loss
                })

                if self.callback:
                    # trainer.train() blocks the event loop, so this task is
                    # only scheduled here; it runs once control returns to
                    # the loop.
                    asyncio.create_task(self.callback(self.manager.current_training_state))

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            callbacks=[ProgressCallback(self, epochs, progress_callback)]
        )

        logger.info("Starting training...")
        train_result = trainer.train()

        logger.info(f"Saving model to: {output_dir}")
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        self.current_training_state.update({
            "is_training": False,
            "progress": 100.0,
            "current_epoch": epochs
        })

        return {
            "success": True,
            "output_dir": output_dir,
            "train_loss": train_result.training_loss,
            "metrics": train_result.metrics,
            "model_name": model_name,
            "lora_config": {
                "r": lora_r,
                "alpha": lora_alpha,
                "dropout": lora_dropout
            }
        }
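
    # To run inference with the adapter saved by fine_tune_with_huggingface_peft,
    # the standard PEFT pattern applies (a sketch, assuming the same base model
    # and the output_dir passed above):
    #
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    #   model = PeftModel.from_pretrained(base, output_dir)
    #   model.eval()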

    async def fine_tune_with_unsloth(
        self,
        model_name: str,
        training_examples: List[Dict[str, Any]],
        output_dir: str,
        epochs: int = 3,
        learning_rate: float = 2e-4,
        batch_size: int = 4,
        lora_r: int = 16,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Fine-tune a model using Unsloth, which advertises roughly 2x the
        throughput of standard PEFT training with lower memory use.

        Recommended for: Llama, Mistral, and Qwen models.
        """
        try:
            from unsloth import FastLanguageModel
            from trl import SFTTrainer
            from transformers import TrainingArguments
            from datasets import Dataset
        except ImportError as e:
            logger.error(f"Unsloth not installed: {e}")
            logger.error("Please install: pip install unsloth")
            raise

        logger.info(f"Starting Unsloth fine-tuning for model: {model_name}")
        self.current_training_state.update({
            "is_training": True,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": epochs,
            "model_id": model_name,
            "start_time": datetime.now(timezone.utc).isoformat()
        })

        logger.info("Loading model with Unsloth...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=2048,
            dtype=None,  # None lets Unsloth auto-select the dtype for the hardware
            load_in_4bit=True
        )

        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_r,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=42
        )

        formatted_texts = self.format_training_examples_for_model(training_examples, template="alpaca")
        dataset = Dataset.from_dict({"text": formatted_texts})

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            save_steps=100,
            warmup_steps=50
        )

        # Note: newer trl releases move dataset_text_field and max_seq_length
        # from SFTTrainer into SFTConfig; adjust for your installed version.
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            dataset_text_field="text",
            max_seq_length=2048,
            args=training_args
        )

        logger.info("Starting training with Unsloth...")
        trainer.train()

        logger.info(f"Saving model to: {output_dir}")
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        self.current_training_state.update({
            "is_training": False,
            "progress": 100.0
        })

        return {
            "success": True,
            "output_dir": output_dir,
            "model_name": model_name,
            "method": "unsloth"
        }
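
    # The adapter saved by fine_tune_with_unsloth can be reloaded through
    # Unsloth itself for inference (a sketch following Unsloth's documented
    # pattern; verify the exact call against your installed version):
    #
    #   model, tokenizer = FastLanguageModel.from_pretrained(model_name=output_dir, load_in_4bit=True)
    #   FastLanguageModel.for_inference(model)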

    async def fine_tune_with_mlx(
        self,
        model_name: str,
        training_examples: List[Dict[str, Any]],
        output_dir: str,
        epochs: int = 3,
        learning_rate: float = 1e-5,
        batch_size: int = 4,
        lora_r: int = 8,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Fine-tune a model using MLX (Apple Silicon only).

        Optimized for M1/M2/M3 Macs, where unified memory keeps the whole
        model resident without host/device copies.

        Note: not yet implemented; this method currently only verifies that
        MLX is installed and then returns an error result.
        """
        try:
            # These imports only verify that MLX is available; they are unused
            # until the training loop is implemented.
            import mlx.core as mx
            from mlx_lm import load, generate
            import mlx.optimizers as optim
            import mlx.nn as nn
        except ImportError as e:
            logger.error(f"MLX not installed: {e}")
            logger.error("Please install: pip install mlx mlx-lm")
            raise

        logger.info(f"Starting MLX fine-tuning for model: {model_name}")
        self.current_training_state.update({
            "is_training": True,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": epochs,
            "model_id": model_name,
            "start_time": datetime.now(timezone.utc).isoformat()
        })

        logger.warning("MLX fine-tuning is not fully implemented yet")

        self.current_training_state["is_training"] = False

        return {
            "success": False,
            "error": "MLX fine-tuning not yet implemented",
            "model_name": model_name
        }
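
    # Until the in-process loop lands, mlx-lm ships a LoRA trainer that can be
    # driven from the command line instead (a sketch; flag names may vary
    # across mlx-lm releases):
    #
    #   python -m mlx_lm.lora --model <model_name> --train --data <data_dir> --iters 600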

    async def start_training(
        self,
        apprentice_model_name: str,
        domain_id: Optional[str] = None,
        method: str = "peft",
        epochs: int = 3,
        learning_rate: float = 2e-4,
        batch_size: int = 4,
        output_name: Optional[str] = None,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Main entry point for fine-tuning an apprentice model.

        Args:
            apprentice_model_name: HuggingFace model name or path.
            domain_id: Domain to train on (None = all domains).
            method: Training method ("peft", "unsloth", "mlx").
            epochs: Number of training epochs.
            learning_rate: Learning rate.
            batch_size: Batch size per device.
            output_name: Custom name for the output directory.
            progress_callback: Async callback for progress updates.

        Returns:
            Training result dictionary.
        """
        training_examples = self.load_training_data(domain_id)

        if not training_examples:
            return {
                "success": False,
                "error": "No training data found"
            }

        if output_name is None:
            timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
            output_name = f"apprentice_{domain_id or 'all'}_{timestamp}"

        output_dir = self.checkpoints_dir / output_name
        output_dir.mkdir(parents=True, exist_ok=True)

        if method == "peft":
            result = await self.fine_tune_with_huggingface_peft(
                model_name=apprentice_model_name,
                training_examples=training_examples,
                output_dir=str(output_dir),
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                progress_callback=progress_callback
            )
        elif method == "unsloth":
            result = await self.fine_tune_with_unsloth(
                model_name=apprentice_model_name,
                training_examples=training_examples,
                output_dir=str(output_dir),
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                progress_callback=progress_callback
            )
        elif method == "mlx":
            result = await self.fine_tune_with_mlx(
                model_name=apprentice_model_name,
                training_examples=training_examples,
                output_dir=str(output_dir),
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                progress_callback=progress_callback
            )
        else:
            return {
                "success": False,
                "error": f"Unknown training method: {method}"
            }

        return result
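
    # A progress_callback passed to start_training receives the same dict that
    # get_training_status() below returns; an illustrative consumer:
    #
    #   async def on_progress(state):
    #       print(f"epoch {state['current_epoch']}/{state['total_epochs']}, loss={state['loss']:.4f}")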

    def get_training_status(self) -> Dict[str, Any]:
        """Get current training status."""
        return self.current_training_state.copy()

    def stop_training(self):
        """Stop current training (if possible)."""
        logger.warning("Training interruption not yet implemented")
        self.current_training_state["is_training"] = False

    def get_training_metrics(self, checkpoint_dir: str) -> Dict[str, Any]:
        """Load training metrics from a checkpoint directory."""
        checkpoint_path = Path(checkpoint_dir)

        if not checkpoint_path.exists():
            return {"error": "Checkpoint not found"}

        # The HuggingFace Trainer writes trainer_state.json inside each
        # checkpoint-* directory it saves.
        trainer_state_file = checkpoint_path / "trainer_state.json"
        if trainer_state_file.exists():
            with open(trainer_state_file, 'r') as f:
                trainer_state = json.load(f)
            return {
                "log_history": trainer_state.get("log_history", []),
                "best_metric": trainer_state.get("best_metric"),
                "best_model_checkpoint": trainer_state.get("best_model_checkpoint")
            }

        return {"error": "No metrics found"}
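

if __name__ == "__main__":
    # Minimal usage sketch: fine-tune an apprentice on all collected domains.
    # Assumes Alpaca-format JSONL files exist under training_data/master_outputs
    # and that the PEFT extras are installed; the base model name below is a
    # hypothetical placeholder, not a project default.
    logging.basicConfig(level=logging.INFO)
    manager = FineTuningManager()
    result = asyncio.run(
        manager.start_training(
            apprentice_model_name="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder base model
            method="peft",
            epochs=1,
        )
    )
    print(json.dumps(result, indent=2, default=str))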