"""Fine-tune Qwen2.5-VL with TRL's ``SFTTrainer`` on image + text data.

Intended for (multi-)process launch via ``accelerate``: each process loads the
model onto its own device (see ``PartialState`` below).

When ``IS_MODAL_REMOTE()`` is true the script:
- resumes training from a saved PEFT adapter;
- trains in bfloat16;
- pushes checkpoints to ``ADAPTER_DEST_HUB_ID``;
- reports to Weights & Biases.

Otherwise it wraps the base model with a fresh PEFT config, trains in float32
and reports nowhere.
"""

import os

# PyTorch reads this when it initialises, so it must be set before the first
# `import torch` -- which the project/transformers imports below may trigger.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # needed for processor

import logging
from dataclasses import asdict
from typing import Dict, Optional

import dotenv
import torch
import wandb
from accelerate import Accelerator, PartialState
from transformers import (
    AutoProcessor,
    EvalPrediction,
    Qwen2_5_VLForConditionalGeneration,
)
from transformers.trainer_utils import EvalPrediction
from trl import SFTConfig, SFTTrainer

from presto_env.env import PrestoEnv
from presto_env.modal import IS_MODAL_REMOTE
from presto_learn.hf.accelerate_utils import parse_accelerate_args
from presto_learn.hf.creator.sft.collators import (
    collate_fn_image_and_text,
    collate_fn_image_only,
)
from presto_learn.hf.creator.sft.config import get_trainer_config
from presto_learn.hf.creator.sft.eval import generate_and_eval
from presto_learn.hf.creator.sft.load_datasets import load_datasets
from presto_learn.hf.peft import get_peft_configs
from presto_learn.library.gpu_log import log_gpu_params_by_device
from presto_learn.library.term_color import Colors

# Set up logger
logger = logging.getLogger(__name__)

ADAPTER_DEST_HUB_ID = "Presto-Design/llm_adapter_vectorizer_qwen7b"


def main(project: str, run_id: str) -> None:
    """Build processor, datasets, model and ``SFTTrainer`` for one run, then train.

    Args:
        project: W&B project name (exported as ``WANDB_PROJECT``). Also passed
            with ``run_id`` to ``PrestoEnv.run_folder`` to pick the output dir.
        run_id: Run name. Used for the output folder, ``SFTConfig.run_name``
            and the W&B run name.
    """
    os.environ["WANDB_PROJECT"] = project
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Evaluate once; every remote-vs-local decision below keys off this flag.
    is_remote = bool(IS_MODAL_REMOTE())

    latest_adapter = (
        "/projects/creator-single-apres/20250326-13-41-fond-guanaco/checkpoint-50000"
    )
    accelerator = Accelerator()
    model_configs = get_peft_configs()

    # Locally we use float32. Lower precision raised:
    #   Expected scalar_type == ScalarType::Float || inputTensor.scalar_type() ==
    #   ScalarType::Int || scalar_type == ScalarType::Bool to be true, but got false.
    # TODO: figure out how to get 16-bit training working locally.
    torch_dtype = torch.bfloat16 if is_remote else torch.float32

    config = get_trainer_config()
    processor = AutoProcessor.from_pretrained(config.base_model, use_fast=True)
    dataset = load_datasets(processor, config, accelerator)

    # Place the whole model on this process's device. This must be the *local*
    # (per-node) index: the global process_index exceeds the number of GPUs on
    # a node in multi-node runs.
    device_index = PartialState().local_process_index
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config.base_model,
        torch_dtype=torch_dtype,
        attn_implementation=model_configs.attn_implementation,
        device_map={"": device_index},
    )

    has_adapter = is_remote
    if has_adapter:
        from peft import PeftModel  # only needed when resuming an adapter

        # Continue training the saved adapter rather than starting a new one.
        model = PeftModel.from_pretrained(model, latest_adapter, is_trainable=True)

    if accelerator.is_main_process:
        Colors.print_dict(asdict(config), "Trainer Config:")

    training_args = SFTConfig(
        output_dir=PrestoEnv.run_folder(project, run_id),
        run_name=run_id,
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.eval_batch_size,
        gradient_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        save_steps=config.eval_steps,
        metric_for_best_model="loss",
        save_total_limit=5,
        logging_steps=config.logging_steps,
        logging_dir="./logs",
        learning_rate=config.learning_rate,
        push_to_hub=is_remote,
        hub_model_id=ADAPTER_DEST_HUB_ID if is_remote else None,
        use_liger=model_configs.use_liger_kernel,
        optim=model_configs.optimizer_name,
        report_to=["wandb"] if is_remote else [],
        # No packing or truncation: image data is prepared manually in the collator.
        dataset_kwargs={"skip_prepare_dataset": True},
        # Keep every column; the collator needs them.
        remove_unused_columns=False,
        fp16=torch_dtype == torch.float16,
        bf16=torch_dtype == torch.bfloat16,
    )

    def data_collator(examples):
        """Collate a batch of raw examples into model inputs (image + text)."""
        return collate_fn_image_and_text(examples, processor, config)

    # Runs on every rank so each process logs its own device placement.
    log_gpu_params_by_device("Model", model)

    def preprocess_logits_for_metrics(logits, labels) -> torch.Tensor:
        """Reduce eval logits to per-token "prediction == label" booleans.

        Gathering full-vocabulary logits across ranks costs too much memory,
        so only this boolean tensor is kept.

        NOTE(review): no ``compute_metrics`` is passed to the trainer below,
        so the Trainer may never call this -- confirm before relying on it.
        """
        # The model can return several tensors; the first is the LM logits.
        # The second looks like Qwen2.5-VL's rope_deltas (e.g. [[-90]]) -- TODO confirm.
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        # Causal LM: logits at position t predict token t + 1, so shift before
        # comparing. Ignored label positions (typically -100) are not masked here.
        predictions = logits[:, :-1].argmax(dim=-1)
        return predictions == labels[:, 1:]

    trainer = SFTTrainer(
        model=model,
        # A resumed PeftModel already carries its adapter. Only ask the trainer
        # to add one when starting from the bare base model.
        peft_config=None if has_adapter else model_configs.peft_config,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=processor.tokenizer,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    if "wandb" in training_args.report_to and trainer.accelerator.is_main_process:
        # to_dict() redacts *_token fields (e.g. hub_token) and converts enums
        # to plain values, unlike dataclasses.asdict.
        wandb.init(
            project=project,
            name=training_args.run_name,
            config=training_args.to_dict(),
        )

    trainer.train()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dotenv.load_dotenv()
    args = parse_accelerate_args(default_project="creator-image-assets")
    main(args.project, args.run_id)