# train.py - TinyState-19B-A9B training: LoRA, optional distillation, safetensors saving
import json
import os
from typing import Optional

import torch
import torch.nn as nn
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from safetensors.torch import load_file, save_file
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from model import TinyStateForCausalLM, TinyStateConfig


def load_model_and_tokenizer(model_path: Optional[str] = None, use_4bit: bool = True):
    """Load the TinyState model and tokenizer."""
    # Load configuration
    config_path = None
    if model_path and os.path.exists(f"{model_path}/config.json"):
        config_path = f"{model_path}/config.json"
    elif model_path and os.path.exists(f"{model_path}/configuration.json"):
        config_path = f"{model_path}/configuration.json"

    if config_path:
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        config = TinyStateConfig(**config_dict)
        print(f"Loaded config from {config_path}")
    else:
        config = TinyStateConfig()
        print("Using default TinyState config")

    # Initialize model
    if model_path and os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        try:
            # Try loading from safetensors first
            model = load_model_from_safetensors(model_path, config)
        except Exception as e:
            print(f"Failed to load from safetensors: {e}")
            print("Initializing new model instead")
            model = TinyStateForCausalLM(config)
    else:
        print("Initializing new TinyState model")
        model = TinyStateForCausalLM(config)

    # Prepare for k-bit training if requested. Note that this only readies
    # the model (layer norms in fp32, input gradients enabled); the weights
    # themselves are only quantized if the model was loaded in 4-bit.
    if use_4bit:
        model = prepare_model_for_kbit_training(model)

    # Load tokenizer; the model is assumed to share the Qwen2 vocabulary
    try:
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B", trust_remote_code=True)
    except Exception:
        # Fall back to the fast tokenizer class directly. If this fails too,
        # there is no usable vocabulary, so let the error propagate.
        from transformers import PreTrainedTokenizerFast

        tokenizer = PreTrainedTokenizerFast.from_pretrained("Qwen/Qwen2-7B")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, config


def load_model_from_safetensors(model_path: str, config):
    """Load model weights from safetensors files (single or chunked)."""
    # Check for an index file (chunked checkpoint)
    index_path = f"{model_path}/model.safetensors.index.json"
    if os.path.exists(index_path):
        print("Loading chunked safetensors model...")
        with open(index_path, "r") as f:
            index = json.load(f)
        state_dict = {}
        for shard_file in set(index["weight_map"].values()):
            shard_path = f"{model_path}/{shard_file}"
            if os.path.exists(shard_path):
                state_dict.update(load_file(shard_path))
        model = TinyStateForCausalLM(config)
        model.load_state_dict(state_dict)
        return model

    # Check for a single safetensors file
    single_path = f"{model_path}/model.safetensors"
    if os.path.exists(single_path):
        print("Loading single safetensors model...")
        state_dict = load_file(single_path)
        model = TinyStateForCausalLM(config)
        model.load_state_dict(state_dict)
        return model

    raise FileNotFoundError(f"No safetensors files found in {model_path}")


def setup_lora_training(model, r: int = 64, alpha: int = 16, dropout: float = 0.1):
    """Setup LoRA for efficient training."""
    lora_config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model
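

# Optional sanity-check sketch (an addition, not part of the training flow):
# the target_modules above follow Llama/Qwen-style projection naming, which
# TinyStateForCausalLM is assumed to share. This helper reports which of the
# expected names a given model actually exposes.
def find_lora_targets(model, targets=("q_proj", "k_proj", "v_proj", "o_proj",
                                      "gate_proj", "up_proj", "down_proj")):
    """Return the subset of LoRA target names present among the model's modules."""
    module_names = [name for name, _ in model.named_modules()]
    return {t for t in targets if any(name.endswith(t) for name in module_names)}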


def load_training_data(
    data_path: str = "data/dataset.jsonl",
    tokenizer=None,
    max_length: int = 2048,
    streaming: bool = False,
):
    """Load and preprocess training data."""

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=max_length,
            return_tensors=None,
        )

    # Load dataset
    try:
        if data_path.endswith((".jsonl", ".json")):
            dataset = load_dataset("json", data_files=data_path, streaming=streaming, split="train")
        else:
            dataset = load_dataset(data_path, streaming=streaming, split="train")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        # Fall back to a dummy dataset so the pipeline can be smoke-tested
        from datasets import Dataset

        dummy_data = {"text": ["This is dummy data for testing."] * 100}
        dataset = Dataset.from_dict(dummy_data)

    # Tokenize dataset
    try:
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names if hasattr(dataset, "column_names") else ["text"],
            desc="Tokenizing dataset",
        )
    except Exception:
        # Fallback: encode each example directly
        def simple_tokenize(examples):
            return {
                "input_ids": [
                    tokenizer.encode(text, max_length=max_length, truncation=True)
                    for text in examples["text"]
                ]
            }

        tokenized_dataset = dataset.map(
            simple_tokenize,
            batched=True,
            desc="Tokenizing dataset (fallback)",
        )

    return tokenized_dataset
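

# load_training_data expects one JSON object per line with a "text" field,
# matching tokenize_function above. An illustrative record (contents are
# hypothetical):
#
#   {"text": "Question: What is MoE routing?\nAnswer: ..."}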


class DistillationTrainer(Trainer):
    """Custom trainer with knowledge distillation."""

    def __init__(self, teacher_model=None, distillation_alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(**kwargs)
        self.teacher_model = teacher_model
        self.distillation_alpha = distillation_alpha
        self.temperature = temperature
        if self.teacher_model is not None:
            self.teacher_model.eval()
            for param in self.teacher_model.parameters():
                param.requires_grad = False

    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
    # Trainer versions pass to compute_loss
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Standard language-modeling loss
        outputs = model(**inputs)
        loss = outputs.loss

        # Distillation loss (if a teacher model is provided)
        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_outputs = self.teacher_model(**inputs)

            # Soften both distributions with the temperature
            student_logits = outputs.logits / self.temperature
            teacher_logits = teacher_outputs.logits / self.temperature

            # KL divergence between the softened distributions, scaled by T^2
            distill_loss = nn.KLDivLoss(reduction="batchmean")(
                nn.functional.log_softmax(student_logits, dim=-1),
                nn.functional.softmax(teacher_logits, dim=-1),
            ) * (self.temperature ** 2)

            # Weighted combination of the two losses
            loss = self.distillation_alpha * loss + (1 - self.distillation_alpha) * distill_loss

        return (loss, outputs) if return_outputs else loss
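

# For reference, the objective computed above, with alpha = distillation_alpha,
# T = temperature, CE the standard LM loss, s/t the student/teacher logits:
#
#   loss = alpha * CE + (1 - alpha) * T^2 * KL(log_softmax(s/T), softmax(t/T))
#
# The T^2 factor keeps the distillation gradients at a comparable scale as
# the temperature changes.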


def train_model(
    model_path: Optional[str] = None,
    data_path: str = "data/dataset.jsonl",
    output_dir: str = "./tinystate-19b-a9b",
    use_lora: bool = True,
    use_4bit: bool = True,
    use_distillation: bool = False,
    teacher_model_path: Optional[str] = None,
    **training_kwargs,
):
    """Main training function."""
    print("=== TinyState-19B-A9B Training ===")

    # Load model and tokenizer
    print("1. Loading model and tokenizer...")
    model, tokenizer, config = load_model_and_tokenizer(model_path, use_4bit=use_4bit)

    # Setup LoRA if requested
    if use_lora:
        print("2. Setting up LoRA training...")
        model = setup_lora_training(model)

    # Load training data
    print("3. Loading training data...")
    train_dataset = load_training_data(data_path, tokenizer)

    # Load teacher model for distillation (if requested)
    teacher_model = None
    if use_distillation and teacher_model_path:
        print("4. Loading teacher model for distillation...")
        try:
            teacher_config = TinyStateConfig()
            teacher_model = TinyStateForCausalLM.from_pretrained(teacher_model_path, config=teacher_config)
            teacher_model.eval()
            for param in teacher_model.parameters():
                param.requires_grad = False
        except Exception as e:
            print(f"Warning: Could not load teacher model: {e}")
            teacher_model = None

    # Setup data collator (causal LM, so no masked-LM objective)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Keys consumed here rather than forwarded to TrainingArguments; the
    # distillation keys must be excluded too, or TrainingArguments rejects them
    consumed_keys = {
        "num_train_epochs", "per_device_train_batch_size", "gradient_accumulation_steps",
        "warmup_steps", "logging_steps", "save_steps", "learning_rate", "use_wandb",
        "distillation_alpha", "temperature",
    }

    # Setup training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=training_kwargs.get("num_train_epochs", 3),
        per_device_train_batch_size=training_kwargs.get("per_device_train_batch_size", 1),
        gradient_accumulation_steps=training_kwargs.get("gradient_accumulation_steps", 16),
        warmup_steps=training_kwargs.get("warmup_steps", 100),
        logging_steps=training_kwargs.get("logging_steps", 10),
        save_steps=training_kwargs.get("save_steps", 500),
        learning_rate=training_kwargs.get("learning_rate", 2e-4),
        fp16=True,
        gradient_checkpointing=True,  # Memory optimization
        lr_scheduler_type="cosine",
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        save_strategy="steps",
        save_total_limit=3,
        dataloader_num_workers=4,
        remove_unused_columns=True,
        report_to="wandb" if training_kwargs.get("use_wandb", False) else "none",
        **{k: v for k, v in training_kwargs.items() if k not in consumed_keys},
    )

    # Initialize trainer
    print("5. Setting up trainer...")
    trainer = DistillationTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        teacher_model=teacher_model,
        distillation_alpha=training_kwargs.get("distillation_alpha", 0.5),
        temperature=training_kwargs.get("temperature", 2.0),
    )

    # Start training
    print("6. Starting training...")
    trainer.train()

    # Save model
    print("7. Saving model...")
    save_model_safetensors(trainer.model, tokenizer, config, output_dir)

    print(f"Training completed! Model saved to {output_dir}")
    return trainer.model


def save_model_safetensors(model, tokenizer, config, output_dir: str, max_shard_size: str = "4GB"):
    """Save model in safetensors format (single file or chunked)."""
    os.makedirs(output_dir, exist_ok=True)

    # Save config, dropping non-serializable entries
    config_dict = {
        k: v for k, v in config.__dict__.items()
        if not callable(v) and not isinstance(v, torch.device)
    }
    with open(f"{output_dir}/config.json", "w") as f:
        json.dump(config_dict, f, indent=2)

    # Save tokenizer
    tokenizer.save_pretrained(output_dir)

    # Get state dict and compute total size
    state_dict = model.state_dict()
    total_size = sum(param.numel() * param.element_size() for param in state_dict.values())
    print(f"Model size: {total_size / (1024**3):.2f} GB")

    # If the model is small enough, save it as a single file
    if total_size < 4 * (1024**3):  # Less than 4 GB
        save_file(state_dict, f"{output_dir}/model.safetensors")
        print("Model saved as a single safetensors file")
    else:
        # Otherwise save chunked files plus an index
        save_model_chunked(state_dict, output_dir, max_shard_size)

    print(f"Model saved in safetensors format to {output_dir}")


def save_model_chunked(state_dict, output_dir: str, max_shard_size: str = "4GB"):
    """Save model as chunked safetensors files with an index."""
    # Convert max_shard_size to bytes
    if max_shard_size.endswith("GB"):
        max_size = int(max_shard_size[:-2]) * (1024**3)
    elif max_shard_size.endswith("MB"):
        max_size = int(max_shard_size[:-2]) * (1024**2)
    else:
        max_size = 4 * (1024**3)  # Default: 4 GB

    # First pass: pack parameters into shards. Sorting by size (descending)
    # gives better packing.
    sorted_params = sorted(
        state_dict.items(),
        key=lambda x: x[1].numel() * x[1].element_size(),
        reverse=True,
    )
    shards = []
    current_shard = {}
    current_size = 0
    for name, param in sorted_params:
        param_size = param.numel() * param.element_size()
        # If adding this parameter would exceed the shard size, close the shard
        if current_size + param_size > max_size and current_shard:
            shards.append(current_shard)
            current_shard = {}
            current_size = 0
        current_shard[name] = param
        current_size += param_size
    if current_shard:
        shards.append(current_shard)

    # Second pass: write the shards once the true shard count is known, so
    # the "-of-NNNNN" suffix in each filename is correct
    total_shards = len(shards)
    weight_map = {}
    for shard_idx, shard in enumerate(shards, start=1):
        shard_filename = f"model-{shard_idx:05d}-of-{total_shards:05d}.safetensors"
        save_file(shard, f"{output_dir}/{shard_filename}")
        print(f"Saved shard {shard_idx}/{total_shards}: {len(shard)} parameters")
        for param_name in shard:
            weight_map[param_name] = shard_filename

    # Index file mapping each parameter to its shard
    index = {
        "metadata": {
            "total_size": sum(p.numel() * p.element_size() for p in state_dict.values())
        },
        "weight_map": weight_map,
    }
    with open(f"{output_dir}/model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total_shards} chunked safetensors files with index")


# Example usage
if __name__ == "__main__":
    # This would be the actual training command; a sketch follows below
    pass
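
    # A minimal sketch of a real invocation; the paths are placeholders, and
    # data/dataset.jsonl is assumed to exist in the JSONL format shown after
    # load_training_data:
    #
    # train_model(
    #     model_path=None,                   # start from a fresh TinyState model
    #     data_path="data/dataset.jsonl",
    #     output_dir="./tinystate-19b-a9b",
    #     use_lora=True,
    #     use_4bit=True,
    #     num_train_epochs=3,
    #     use_wandb=False,
    # )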