# train_aries.py
# Skeleton training pipeline for:
#   - SFT (supervised fine-tuning)
#   - hooks to plug in GRPO/TRL reward models (placeholders provided)
#
# Usage:
#   export HF_TOKEN="hf_xxx"
#   python train_aries.py --data /path/to/data.jsonl --output_dir /path/to/out --epochs 3 --batch 2
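#
# Data format: one JSON object per line in data.jsonl, e.g. (illustrative values):
#   {"prompt": "Translate to French: hello", "response": "bonjour"}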
import argparse

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

def load_tokenizer_and_model(repo_or_local):
    tok = AutoTokenizer.from_pretrained(repo_or_local, trust_remote_code=True, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(repo_or_local, trust_remote_code=True)
    # Causal-LM tokenizers often ship without a pad token; reuse EOS so batching works.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok, model

def prepare_dataset(path, tok, max_length=512):
    # Expects JSONL records of the form {"prompt": "...", "response": "..."}.
    ds = load_dataset('json', data_files={'train': str(path)}, split='train')

    def map_fn(x):
        text = x.get('prompt', '') + '\n' + x.get('response', '')
        return tok(text, truncation=True, max_length=max_length)

    # Drop the raw text columns so only tensor inputs reach the Trainer.
    ds = ds.map(map_fn, batched=False, remove_columns=ds.column_names)
    ds.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    return ds

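# --- Optional RL hook (not wired into main) ---------------------------------
# A minimal sketch of a reward-driven PPO loop with `trl`, kept here so the
# GRPO/TRL placeholders below have something concrete to point at. It assumes
# the classic PPOTrainer API (trl <= 0.11; later releases restructured this
# interface), and `reward_fn` is a hypothetical callable you must supply,
# e.g. a scoring pass through a separate reward model.
def run_ppo_sketch(repo_or_local, prompts, reward_fn, max_new_tokens=64):
    from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

    tok = AutoTokenizer.from_pretrained(repo_or_local, trust_remote_code=True, use_fast=False)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # PPO needs a value head on top of the policy model.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(repo_or_local)

    config = PPOConfig(model_name=repo_or_local, batch_size=len(prompts), mini_batch_size=1)
    ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tok)

    query_tensors = [tok(p, return_tensors='pt').input_ids.squeeze(0) for p in prompts]
    response_tensors = ppo_trainer.generate(
        query_tensors, return_prompt=False, max_new_tokens=max_new_tokens
    )
    responses = [tok.decode(r, skip_special_tokens=True) for r in response_tensors]

    # reward_fn(prompts, responses) -> list[float], one scalar score per sample.
    rewards = [torch.tensor(r) for r in reward_fn(prompts, responses)]
    return ppo_trainer.step(query_tensors, response_tensors, rewards)
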
def main():
    p = argparse.ArgumentParser()
    p.add_argument('--data', required=True)
    p.add_argument('--repo', default='.', help='local folder or HF repo id')
    p.add_argument('--output_dir', default='./out')
    p.add_argument('--epochs', type=int, default=1)
    p.add_argument('--batch', type=int, default=2)
    args = p.parse_args()

    tok, model = load_tokenizer_and_model(args.repo)
    ds = prepare_dataset(args.data, tok)

    # bf16 and fp16 are mutually exclusive in TrainingArguments: prefer bf16
    # where the GPU supports it, fall back to fp16, else train in fp32.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    use_fp16 = torch.cuda.is_available() and not use_bf16

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch,
        num_train_epochs=args.epochs,
        bf16=use_bf16,
        fp16=use_fp16,
        logging_steps=10,
        save_strategy='epoch',
        push_to_hub=False,
    )

    # Basic SFT trainer. The collator copies input_ids into labels (mlm=False),
    # which the Trainer needs in order to compute the causal-LM loss.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        tokenizer=tok,
        data_collator=DataCollatorForLanguageModeling(tok, mlm=False),
    )
    trainer.train()

    # === Hooks: attach GRPO/TRL ===
    # After SFT completes, you may want to:
    #   1) initialize a reward model and a KTO/GRPO loop (placeholder)
    #   2) use `trl`'s PPOTrainer or a custom GRPO trainer
    # run_ppo_sketch() above shows one hedged way to wire this up; it is not
    # called here by default.
print("Done SFT. Model checkpoint in", args.output_dir)
if __name__ == '__main__':
main()