File size: 5,336 Bytes
9486adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os

from presto_learn.hf.creator.sft.load_datasets import load_datasets

# Turn on PyTorch's MPS CPU-fallback switch before the model/processor are used.
# NOTE(review): PyTorch may only read this variable when torch is first imported.
# The `load_datasets` import above probably pulls in torch/transformers already,
# so confirm whether this line needs to move above every third-party import.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # needed for processor

import logging

from typing import Dict, Optional
import dotenv
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    EvalPrediction,
)
import wandb
from accelerate import Accelerator

from presto_learn.library.gpu_log import log_gpu_params_by_device

from presto_learn.library.term_color import Colors
from dataclasses import asdict

from accelerate import PartialState
from trl import SFTTrainer, SFTConfig
import torch

from presto_env.modal import IS_MODAL_REMOTE
from presto_learn.hf.accelerate_utils import parse_accelerate_args
from presto_learn.hf.creator.sft.collators import collate_fn_image_and_text, collate_fn_image_only
from presto_learn.hf.creator.sft.config import get_trainer_config
from presto_learn.hf.creator.sft.eval import generate_and_eval
from presto_learn.hf.peft import get_peft_configs
from presto_env.env import PrestoEnv
from transformers.trainer_utils import EvalPrediction


# Hugging Face Hub repository that receives the trained adapter on Modal runs.
ADAPTER_DEST_HUB_ID = "Presto-Design/llm_adapter_vectorizer_qwen7b"

# Module-level logger; its output is configured by the __main__ entry point.
logger = logging.getLogger(name=__name__)

# Adapter checkpoint that Modal runs resume training from unless another path is given.
DEFAULT_ADAPTER_PATH = (
    "/projects/creator-single-apres/20250326-13-41-fond-guanaco/checkpoint-50000"
)

# Label id marking positions excluded from the loss (Hugging Face convention).
# NOTE(review): presumably the collators mask prompt/padding tokens with this
# value — confirm in collate_fn_image_and_text.
_IGNORE_INDEX = -100


def preprocess_logits_for_metrics(logits, labels: torch.Tensor) -> torch.Tensor:
    """Reduce eval logits to per-token next-token correctness.

    Gathering full-vocabulary logits is too memory-hungry, so each eval batch
    is collapsed to argmax predictions and compared with the labels before the
    Trainer accumulates it.

    Causal-LM logits at position ``t`` predict token ``t + 1`` (the labels
    passed to the model are unshifted), so predictions come from positions
    ``[0, T-1)`` and are compared with labels at ``[1, T)``.

    Args:
        logits: A logits tensor, or a tuple/list whose first element is the
            logits tensor. Assumed shape (..., seq_len, vocab) — TODO confirm.
        labels: Unshifted label ids, assumed (..., seq_len). Positions equal
            to -100 are treated as ignored.

    Returns:
        A long tensor of shape (..., seq_len - 1). Each entry is 1 where the
        argmax prediction matches the next label and 0 where it does not.
        Entries for ignored labels are -100. A long dtype is used because a
        bool tensor cannot carry the -100 marker.
    """
    if isinstance(logits, (tuple, list)):
        # Extra model outputs ride along with the logits (observed as a small
        # tensor like [[-90]]; presumably Qwen2.5-VL's rope_deltas).
        logits = logits[0]
    predictions = logits[..., :-1, :].argmax(dim=-1)
    targets = labels[..., 1:]
    correct = (predictions == targets).long()
    return correct.masked_fill(targets == _IGNORE_INDEX, _IGNORE_INDEX)


def _load_model(config, model_configs, torch_dtype, adapter_path, is_remote):
    """Load the base Qwen2.5-VL model and wrap it in the saved adapter when on Modal.

    Returns:
        ``(model, has_adapter)``. ``has_adapter`` is True when a trainable
        PEFT adapter was loaded from ``adapter_path``. In that case the
        trainer must not attach a fresh ``peft_config``.
    """
    # The whole model ("" = root module) goes on the device whose index equals
    # this process's rank, so every process holds its own full copy.
    # NOTE(review): confirm how the integer index resolves on the local
    # (non-CUDA) path.
    device_index = PartialState().process_index
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        config.base_model,
        torch_dtype=torch_dtype,
        attn_implementation=model_configs.attn_implementation,
        device_map={"": device_index},
    )
    if not is_remote:
        return model, False

    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adapter_path, is_trainable=True)
    return model, True


def _build_training_args(project, run_id, config, model_configs, torch_dtype, is_remote):
    """Build the SFTConfig for this run; Hub push and W&B reporting apply only on Modal."""
    return SFTConfig(
        output_dir=PrestoEnv.run_folder(project, run_id),
        run_name=run_id,
        num_train_epochs=config.epochs,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.eval_batch_size,
        gradient_accumulation_steps=1,
        # Evaluation and checkpointing share one cadence.
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        save_steps=config.eval_steps,
        metric_for_best_model="loss",
        save_total_limit=5,
        logging_steps=config.logging_steps,
        logging_dir="./logs",
        learning_rate=config.learning_rate,
        push_to_hub=is_remote,
        hub_model_id=ADAPTER_DEST_HUB_ID if is_remote else None,
        use_liger=model_configs.use_liger_kernel,
        optim=model_configs.optimizer_name,
        report_to=["wandb"] if is_remote else [],
        # No packing or truncation: the collator prepares image + text itself.
        dataset_kwargs={"skip_prepare_dataset": True},
        # Keep raw columns so the collator can process them during collation.
        remove_unused_columns=False,
        fp16=torch_dtype == torch.float16,
        bf16=torch_dtype == torch.bfloat16,
    )


def main(project: str, run_id: str, adapter_path: Optional[str] = None):
    """Fine-tune Qwen2.5-VL on image+text data with TRL's SFTTrainer.

    On Modal (``IS_MODAL_REMOTE()``) the run:
    - trains in bf16,
    - resumes from the adapter at ``adapter_path``,
    - pushes checkpoints to ``ADAPTER_DEST_HUB_ID``,
    - reports to W&B.

    Elsewhere it trains in float32 with a fresh PEFT config from
    ``get_peft_configs()`` and reports nowhere.

    Args:
        project: W&B project name; also used with ``run_id`` to derive the
            output folder via ``PrestoEnv.run_folder``.
        run_id: Run name for the trainer and W&B.
        adapter_path: Adapter checkpoint to resume from on Modal. Defaults
            to ``DEFAULT_ADAPTER_PATH``.
    """
    os.environ["WANDB_PROJECT"] = project
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if adapter_path is None:
        adapter_path = DEFAULT_ADAPTER_PATH

    accelerator = Accelerator()
    model_configs = get_peft_configs()
    is_remote = bool(IS_MODAL_REMOTE())
    # Local runs use float32: 16-bit training failed there with
    #   "Expected scalar_type == ScalarType::Float || inputTensor.scalar_type()
    #    == ScalarType::Int || scalar_type == ScalarType::Bool to be true"
    # Modal runs use bf16.
    torch_dtype = torch.bfloat16 if is_remote else torch.float32
    config = get_trainer_config()

    processor = AutoProcessor.from_pretrained(config.base_model, use_fast=True)
    dataset = load_datasets(processor, config, accelerator)

    model, has_adapter = _load_model(
        config, model_configs, torch_dtype, adapter_path, is_remote
    )

    if accelerator.is_main_process:
        Colors.print_dict(asdict(config), "Trainer Config:")

    training_args = _build_training_args(
        project, run_id, config, model_configs, torch_dtype, is_remote
    )

    def data_collator(examples):
        return collate_fn_image_and_text(examples, processor, config)

    # Runs on every rank (there is no main-process-only guard).
    log_gpu_params_by_device("Model", model)

    trainer = SFTTrainer(
        model=model,
        # A resumed adapter is already trainable; only attach a new one otherwise.
        peft_config=None if has_adapter else model_configs.peft_config,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=processor.tokenizer,
        # NOTE(review): no compute_metrics is passed. Recent transformers
        # versions then evaluate loss only and never call this hook — confirm
        # whether a metric should be wired up.
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    if "wandb" in training_args.report_to and trainer.accelerator.is_main_process:
        # to_dict(), unlike dataclasses.asdict, converts enums to plain values
        # and redacts *_token fields (e.g. hub_token) before they reach W&B.
        wandb.init(
            project=project,
            name=training_args.run_name,
            config=training_args.to_dict(),
        )

    trainer.train()


if __name__ == "__main__":
    # Script entry point:
    # 1. Configure root logging.
    # 2. Load variables from a .env file.
    # 3. Parse the CLI arguments and launch training.
    logging.basicConfig(level=logging.INFO)
    dotenv.load_dotenv()
    cli_args = parse_accelerate_args(default_project="creator-image-assets")
    main(project=cli_args.project, run_id=cli_args.run_id)