davidhhmack's picture
Training in progress, step 2000
9486adc verified
Raw
History Blame Contribute Delete
5.34 kB
import os
from presto_learn.hf.creator.sft.load_datasets import load_datasets
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # needed for processor
import logging
from typing import Dict, Optional
import dotenv
from transformers import (
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
EvalPrediction,
)
import wandb
from accelerate import Accelerator
from presto_learn.library.gpu_log import log_gpu_params_by_device
from presto_learn.library.term_color import Colors
from dataclasses import asdict
from accelerate import PartialState
from trl import SFTTrainer, SFTConfig
import torch
from presto_env.modal import IS_MODAL_REMOTE
from presto_learn.hf.accelerate_utils import parse_accelerate_args
from presto_learn.hf.creator.sft.collators import collate_fn_image_and_text, collate_fn_image_only
from presto_learn.hf.creator.sft.config import get_trainer_config
from presto_learn.hf.creator.sft.eval import generate_and_eval
from presto_learn.hf.peft import get_peft_configs
from presto_env.env import PrestoEnv
from transformers.trainer_utils import EvalPrediction
# Hugging Face Hub repo used as `hub_model_id` (and push target) for Modal runs.
ADAPTER_DEST_HUB_ID = "Presto-Design/llm_adapter_vectorizer_qwen7b"

# Module logger; level/handlers are configured by the script entry point.
logger = logging.getLogger(name=__name__)
def preprocess_logits_for_metrics(logits, labels) -> torch.Tensor:
    """Reduce eval logits to per-token next-token correctness flags.

    Gathering full-vocabulary logits is too memory-hungry, so they are
    collapsed to a boolean mask here, before the Trainer gathers them.

    Args:
        logits: Eval outputs as handed over by the Trainer. For Qwen2.5-VL this
            arrives as a tuple whose first item is the LM logits (the small
            second item, e.g. ``[[-90]]``, is presumably ``rope_deltas`` --
            TODO confirm). A bare logits tensor is accepted as well.
        labels: Token ids aligned with ``input_ids`` (assumed unshifted, the
            usual HF causal-LM convention -- the model shifts internally when
            computing its loss), with ignored positions set to -100.

    Returns:
        Boolean tensor shaped like ``labels``. Position ``t`` is True when the
        prediction made at position ``t - 1`` equals ``labels[t]``. Position 0
        has no preceding prediction and is always False; positions labelled
        -100 are always False because argmax ids are never negative.
    """
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    predictions = logits.argmax(dim=-1)
    correct = torch.zeros_like(labels, dtype=torch.bool)
    # Causal LMs predict token t from position t - 1, so compare shifted.
    correct[..., 1:] = predictions[..., :-1] == labels[..., 1:]
    return correct


def _load_model(config, model_configs, torch_dtype, adapter_path: Optional[str] = None):
    """Load the Qwen2.5-VL base model onto this rank's device, optionally with a trainable adapter."""
    # Pin the whole model to this process's device so every rank holds a full
    # replica. local_process_index is the rank within this node; the global
    # process_index can exceed the GPU count on every node after the first.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config.base_model,
        torch_dtype=torch_dtype,
        attn_implementation=model_configs.attn_implementation,
        device_map={"": PartialState().local_process_index},
    )
    if adapter_path is not None:
        # Only needed when resuming from an existing adapter checkpoint.
        from peft import PeftModel

        # is_trainable=True so training continues updating the loaded adapter.
        model = PeftModel.from_pretrained(
            model,
            adapter_path,
            is_trainable=True,
        )
    return model


def _build_training_args(project: str, run_id: str, config, model_configs, torch_dtype, is_remote: bool):
    """Assemble the SFTConfig for this run from the trainer and PEFT configs."""
    return SFTConfig(
        output_dir=PrestoEnv.run_folder(project, run_id),
        run_name=run_id,
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.eval_batch_size,
        gradient_accumulation_steps=1,
        # Evaluate and checkpoint on the same cadence so checkpoints line up
        # with evaluations (eval loss is the best-model metric).
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        save_steps=config.eval_steps,
        metric_for_best_model="loss",
        save_total_limit=5,
        logging_steps=config.logging_steps,
        logging_dir="./logs",
        learning_rate=config.learning_rate,
        # Hub push and W&B reporting are enabled only on Modal.
        push_to_hub=bool(is_remote),
        hub_model_id=ADAPTER_DEST_HUB_ID if is_remote else None,
        use_liger=model_configs.use_liger_kernel,
        optim=model_configs.optimizer_name,
        report_to=["wandb"] if is_remote else [],
        # Skip TRL's dataset preparation (no packing or truncation): image
        # data is prepared manually in the collator.
        dataset_kwargs={"skip_prepare_dataset": True},
        # Keep every dataset column; the collator needs them.
        remove_unused_columns=False,
        fp16=torch_dtype == torch.float16,
        bf16=torch_dtype == torch.bfloat16,
    )


def main(project: str, run_id: str):
    """Fine-tune Qwen2.5-VL with TRL's SFTTrainer on the creator dataset.

    On Modal (``IS_MODAL_REMOTE()``) training resumes from a fixed adapter
    checkpoint in bf16, pushes to ``ADAPTER_DEST_HUB_ID`` and reports to W&B.
    Elsewhere it trains a fresh PEFT adapter in float32, with no hub push and
    no reporting.

    Args:
        project: W&B project name; also selects the run output folder.
        run_id: Run name used for the output folder and the W&B run.
    """
    os.environ["WANDB_PROJECT"] = project
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Evaluated once: it decides dtype, adapter resume, hub push and reporting.
    is_remote = IS_MODAL_REMOTE()
    # Adapter checkpoint that Modal runs resume training from.
    latest_adapter = (
        "/projects/creator-single-apres/20250326-13-41-fond-guanaco/checkpoint-50000"
    )

    accelerator = Accelerator()
    model_configs = get_peft_configs()
    config = get_trainer_config()
    # bf16 on Modal. Off Modal we use float32 because half precision failed with
    # "Expected scalar_type == ScalarType::Float || inputTensor.scalar_type() ==
    # ScalarType::Int || scalar_type == ScalarType::Bool to be true, but got false."
    # TODO: get 16-bit training working off Modal as well.
    torch_dtype = torch.bfloat16 if is_remote else torch.float32

    processor = AutoProcessor.from_pretrained(config.base_model, use_fast=True)
    dataset = load_datasets(processor, config, accelerator)

    # Modal runs continue training the existing adapter; other runs let
    # SFTTrainer attach a fresh one from model_configs.peft_config.
    has_adapter = bool(is_remote)
    model = _load_model(
        config,
        model_configs,
        torch_dtype,
        adapter_path=latest_adapter if has_adapter else None,
    )

    if accelerator.is_main_process:
        Colors.print_dict(asdict(config), "Trainer Config:")

    training_args = _build_training_args(
        project, run_id, config, model_configs, torch_dtype, is_remote
    )

    def data_collator(examples):
        """Collate a list of dataset rows with the project's image+text collator."""
        return collate_fn_image_and_text(examples, processor, config)

    # Runs on every rank (no main-process guard).
    log_gpu_params_by_device("Model", model)

    # NOTE(review): no compute_metrics is passed, so the Trainer may never
    # invoke preprocess_logits_for_metrics during evaluation -- confirm.
    trainer = SFTTrainer(
        model=model,
        peft_config=None if has_adapter else model_configs.peft_config,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=processor.tokenizer,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    if "wandb" in training_args.report_to and trainer.accelerator.is_main_process:
        # to_dict() rather than dataclasses.asdict(): it redacts hub/auth
        # tokens and converts enums to plain values before they reach W&B.
        wandb.init(
            project=project,
            name=training_args.run_name,
            config=training_args.to_dict(),
        )

    trainer.train()
if __name__ == "__main__":
    # Script entry point: configure INFO-level root logging, load environment
    # variables from a .env file, then parse the CLI and start training.
    logging.basicConfig(level=logging.INFO)
    dotenv.load_dotenv()
    cli_args = parse_accelerate_args(default_project="creator-image-assets")
    main(project=cli_args.project, run_id=cli_args.run_id)