import torch
import torch.nn as nn
from transformers import AutoTokenizer, TrainingArguments, Trainer
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import json
import os
from typing import Optional
from safetensors.torch import save_file, load_file

from model import TinyStateForCausalLM, TinyStateConfig
|
|
def load_model_and_tokenizer(model_path: Optional[str] = None, use_4bit: bool = True):
    """Load the TinyState model and tokenizer"""
    # Resolve the config file, accepting either naming convention.
    config_path = None
    if model_path and os.path.exists(f"{model_path}/config.json"):
        config_path = f"{model_path}/config.json"
    elif model_path and os.path.exists(f"{model_path}/configuration.json"):
        config_path = f"{model_path}/configuration.json"

    if config_path:
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        config = TinyStateConfig(**config_dict)
        print(f"Loaded config from {config_path}")
    else:
        config = TinyStateConfig()
        print("Using default TinyState config")

    # Load existing weights if available, otherwise initialize from scratch.
    if model_path and os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        try:
            model = load_model_from_safetensors(model_path, config)
        except Exception as e:
            print(f"Failed to load from safetensors: {e}")
            print("Initializing new model instead")
            model = TinyStateForCausalLM(config)
    else:
        print("Initializing new TinyState model")
        model = TinyStateForCausalLM(config)

    if use_4bit:
        # Prepare for k-bit training: upcasts norms and enables input gradients.
        model = prepare_model_for_kbit_training(model)

    # Reuse the Qwen2 tokenizer for its vocabulary.
    try:
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B", trust_remote_code=True)
    except Exception:
        try:
            from transformers import PreTrainedTokenizerFast
            tokenizer = PreTrainedTokenizerFast.from_pretrained("Qwen/Qwen2-7B")
        except Exception as e:
            # The PreTrainedTokenizer base class cannot be instantiated without
            # vocabulary files, so fail loudly rather than return a broken stub.
            raise RuntimeError(f"Could not load a tokenizer: {e}")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer, config
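# Example usage (a sketch; the checkpoint path is hypothetical):
#   model, tokenizer, config = load_model_and_tokenizer("./tinystate-19b-a9b", use_4bit=False)
#   print(config)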
|
|
|
|
|
def load_model_from_safetensors(model_path: str, config):
    """Load model from safetensors files (single or chunked)"""
    # Chunked checkpoints ship an index mapping each tensor to its shard.
    index_path = f"{model_path}/model.safetensors.index.json"
    if os.path.exists(index_path):
        print("Loading chunked safetensors model...")
        with open(index_path, 'r') as f:
            index = json.load(f)

        state_dict = {}
        for shard_file in set(index["weight_map"].values()):
            shard_path = f"{model_path}/{shard_file}"
            if os.path.exists(shard_path):
                shard_dict = load_file(shard_path)
                state_dict.update(shard_dict)

        model = TinyStateForCausalLM(config)
        model.load_state_dict(state_dict)
        return model

    # Fall back to a single-file checkpoint.
    single_path = f"{model_path}/model.safetensors"
    if os.path.exists(single_path):
        print("Loading single safetensors model...")
        state_dict = load_file(single_path)
        model = TinyStateForCausalLM(config)
        model.load_state_dict(state_dict)
        return model

    raise FileNotFoundError(f"No safetensors files found in {model_path}")
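# For reference, model.safetensors.index.json follows the standard
# Hugging Face layout, e.g.:
#   {
#     "metadata": {"total_size": 123456789},
#     "weight_map": {"model.embed_tokens.weight": "model-00001-of-00003.safetensors", ...}
#   }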
|
|
|
|
|
def setup_lora_training(model, r: int = 64, alpha: int = 16, dropout: float = 0.1):
    """Setup LoRA for efficient training"""
    lora_config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        # Target all attention and MLP projections.
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model
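# LoRA replaces each targeted weight update with a low-rank product:
# W' = W + (alpha / r) * B @ A, where A is (r x d_in) and B is (d_out x r).
# With the defaults above (r=64, alpha=16) the update is scaled by 16/64 = 0.25,
# and only A and B are trained while W stays frozen.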
|
|
|
|
|
def load_training_data(
    data_path: str = "data/dataset.jsonl",
    tokenizer=None,
    max_length: int = 2048,
    streaming: bool = False
):
    """Load and preprocess training data"""

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=max_length,
            return_tensors=None,
        )

    # Local JSON/JSONL files and hub datasets are both supported.
    try:
        if data_path.endswith('.jsonl') or data_path.endswith('.json'):
            dataset = load_dataset('json', data_files=data_path, streaming=streaming, split='train')
        else:
            dataset = load_dataset(data_path, streaming=streaming, split='train')
    except Exception as e:
        print(f"Error loading dataset: {e}")
        # Fall back to a tiny in-memory dataset so the pipeline still runs.
        from datasets import Dataset
        dummy_data = {"text": ["This is dummy data for testing."] * 100}
        dataset = Dataset.from_dict(dummy_data)

    try:
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names if hasattr(dataset, 'column_names') else ["text"],
            desc="Tokenizing dataset",
        )
    except Exception:
        # Some tokenizers reject the batched call; encode texts one by one.
        def simple_tokenize(examples):
            return {"input_ids": [tokenizer.encode(text, max_length=max_length, truncation=True) for text in examples["text"]]}

        tokenized_dataset = dataset.map(
            simple_tokenize,
            batched=True,
            desc="Tokenizing dataset (fallback)",
        )

    return tokenized_dataset
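# The loader expects a "text" column; for JSONL that means one object per
# line, e.g.:
#   {"text": "Once upon a time..."}
#   {"text": "The quick brown fox..."}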
|
|
|
|
|
class DistillationTrainer(Trainer):
    """Custom trainer with knowledge distillation"""

    def __init__(self, teacher_model=None, distillation_alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(**kwargs)
        self.teacher_model = teacher_model
        self.distillation_alpha = distillation_alpha
        self.temperature = temperature

        # Freeze the teacher: it only provides soft targets.
        if self.teacher_model is not None:
            self.teacher_model.eval()
            for param in self.teacher_model.parameters():
                param.requires_grad = False

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # `**kwargs` absorbs extra arguments (e.g. num_items_in_batch) that
        # newer transformers versions pass to compute_loss.
        outputs = model(**inputs)
        loss = outputs.loss

        if self.teacher_model is not None:
            with torch.no_grad():
                teacher_outputs = self.teacher_model(**inputs)

            # Soften both distributions with the temperature.
            student_logits = outputs.logits / self.temperature
            teacher_logits = teacher_outputs.logits / self.temperature

            # KL divergence on the softened distributions, scaled by T^2 so
            # gradient magnitudes stay comparable to the hard-label loss.
            distill_loss = nn.KLDivLoss(reduction='batchmean')(
                nn.LogSoftmax(dim=-1)(student_logits),
                nn.Softmax(dim=-1)(teacher_logits)
            ) * (self.temperature ** 2)

            # Blend the hard-label loss with the distillation loss.
            loss = self.distillation_alpha * loss + (1 - self.distillation_alpha) * distill_loss

        return (loss, outputs) if return_outputs else loss
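# Worked example of the blended loss with the defaults
# (distillation_alpha=0.5, temperature=2.0): if the hard-label loss is 2.4
# and the raw KL term is 0.3, the KL is scaled by T^2 = 4 to 1.2, giving
# loss = 0.5 * 2.4 + 0.5 * 1.2 = 1.8.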
|
|
|
|
|
def train_model(
    model_path: Optional[str] = None,
    data_path: str = "data/dataset.jsonl",
    output_dir: str = "./tinystate-19b-a9b",
    use_lora: bool = True,
    use_4bit: bool = True,
    use_distillation: bool = False,
    teacher_model_path: Optional[str] = None,
    **training_kwargs
):
    """Main training function"""
    print("=== TinyState-19B-A9B Training ===")

    print("1. Loading model and tokenizer...")
    model, tokenizer, config = load_model_and_tokenizer(model_path, use_4bit=use_4bit)

    if use_lora:
        print("2. Setting up LoRA training...")
        model = setup_lora_training(model)

    print("3. Loading training data...")
    train_dataset = load_training_data(data_path, tokenizer)

    # Optionally load a frozen teacher for knowledge distillation.
    teacher_model = None
    if use_distillation and teacher_model_path:
        print("4. Loading teacher model for distillation...")
        try:
            teacher_config = TinyStateConfig()
            teacher_model = TinyStateForCausalLM.from_pretrained(teacher_model_path, config=teacher_config)
            teacher_model.eval()
            for param in teacher_model.parameters():
                param.requires_grad = False
        except Exception as e:
            print(f"Warning: Could not load teacher model: {e}")
            teacher_model = None

    # Causal LM collator: labels come from the inputs, no masking.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Keys consumed here (or by the trainer) are filtered out so the rest of
    # training_kwargs can be forwarded to TrainingArguments without clashes.
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=training_kwargs.get("num_train_epochs", 3),
        per_device_train_batch_size=training_kwargs.get("per_device_train_batch_size", 1),
        gradient_accumulation_steps=training_kwargs.get("gradient_accumulation_steps", 16),
        warmup_steps=training_kwargs.get("warmup_steps", 100),
        logging_steps=training_kwargs.get("logging_steps", 10),
        save_steps=training_kwargs.get("save_steps", 500),
        learning_rate=training_kwargs.get("learning_rate", 2e-4),
        fp16=True,
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        save_strategy="steps",
        save_total_limit=3,
        dataloader_num_workers=4,
        remove_unused_columns=True,
        report_to="wandb" if training_kwargs.get("use_wandb", False) else "none",
        **{k: v for k, v in training_kwargs.items() if k not in [
            "num_train_epochs", "per_device_train_batch_size", "gradient_accumulation_steps",
            "warmup_steps", "logging_steps", "save_steps", "learning_rate", "use_wandb",
            "distillation_alpha", "temperature",
        ]}
    )

    print("5. Setting up trainer...")
    trainer = DistillationTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        teacher_model=teacher_model,
        distillation_alpha=training_kwargs.get("distillation_alpha", 0.5),
        temperature=training_kwargs.get("temperature", 2.0),
    )

    print("6. Starting training...")
    trainer.train()

    print("7. Saving model...")
    save_model_safetensors(trainer.model, tokenizer, config, output_dir)

    print(f"Training completed! Model saved to {output_dir}")
    return trainer.model
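# Example invocation (a sketch; paths are hypothetical):
#   train_model(
#       data_path="data/dataset.jsonl",
#       output_dir="./tinystate-19b-a9b",
#       use_lora=True,
#       num_train_epochs=1,
#       learning_rate=1e-4,
#       use_wandb=False,
#   )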
|
|
|
|
|
def save_model_safetensors(model, tokenizer, config, output_dir: str, max_shard_size: str = "4GB"):
    """Save model in safetensors format (single or chunked)"""
    os.makedirs(output_dir, exist_ok=True)

    # Persist the config, dropping anything JSON cannot serialize.
    config_dict = config.__dict__.copy()
    config_dict = {k: v for k, v in config_dict.items() if not callable(v) and not isinstance(v, torch.device)}
    with open(f"{output_dir}/config.json", 'w') as f:
        json.dump(config_dict, f, indent=2)

    tokenizer.save_pretrained(output_dir)

    state_dict = model.state_dict()

    total_size = sum(param.numel() * param.element_size() for param in state_dict.values())
    print(f"Model size: {total_size / (1024**3):.2f} GB")

    # Under 4 GB fits comfortably in one file; otherwise shard.
    if total_size < 4 * (1024**3):
        save_file(state_dict, f"{output_dir}/model.safetensors")
        print("Model saved as single safetensors file")
    else:
        save_model_chunked(state_dict, output_dir, max_shard_size)

    print(f"Model saved in safetensors format to {output_dir}")
|
|
|
|
|
def save_model_chunked(state_dict, output_dir: str, max_shard_size: str = "4GB"):
    """Save model as chunked safetensors files"""
    # Parse the human-readable shard size.
    if max_shard_size.endswith("GB"):
        max_size = int(max_shard_size[:-2]) * (1024**3)
    elif max_shard_size.endswith("MB"):
        max_size = int(max_shard_size[:-2]) * (1024**2)
    else:
        max_size = 4 * (1024**3)

    # First pass: partition parameters into shards without writing anything,
    # so the filenames can embed the true shard count rather than a
    # hardcoded total.
    shards = []
    current_shard = {}
    current_size = 0

    # Place the largest tensors first to pack shards more evenly.
    sorted_params = sorted(state_dict.items(), key=lambda x: x[1].numel() * x[1].element_size(), reverse=True)

    for name, param in sorted_params:
        param_size = param.numel() * param.element_size()
        if current_size + param_size > max_size and current_shard:
            shards.append(current_shard)
            current_shard = {}
            current_size = 0
        current_shard[name] = param
        current_size += param_size

    if current_shard:
        shards.append(current_shard)

    # Second pass: write each shard and build the index as we go.
    total = len(shards)
    index = {
        "metadata": {
            "total_size": sum(p.numel() * p.element_size() for p in state_dict.values())
        },
        "weight_map": {}
    }
    for shard_idx, shard in enumerate(shards, start=1):
        shard_filename = f"model-{shard_idx:05d}-of-{total:05d}.safetensors"
        save_file(shard, f"{output_dir}/{shard_filename}")
        print(f"Saved shard {shard_idx}: {len(shard)} parameters")
        for param_name in shard.keys():
            index["weight_map"][param_name] = shard_filename

    with open(f"{output_dir}/model.safetensors.index.json", 'w') as f:
        json.dump(index, f, indent=2)

    print(f"Saved {total} chunked safetensors files with index")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
pass |