stmasson committed on
Commit
5476247
·
verified ·
1 Parent(s): 285df1b

Upload scripts/train_alizee_v2_stage1_sft.py with huggingface_hub

Browse files
scripts/train_alizee_v2_stage1_sft.py CHANGED
@@ -1,6 +1,7 @@
1
  #!/usr/bin/env python3
2
  # /// script
3
  # dependencies = [
 
4
  # "trl>=0.17.0",
5
  # "peft>=0.14.0",
6
  # "transformers>=4.48.0",
@@ -8,7 +9,6 @@
8
  # "bitsandbytes>=0.45.0",
9
  # "trackio",
10
  # "datasets>=3.0.0",
11
- # "flash-attn>=2.5.0",
12
  # ]
13
  # ///
14
 
@@ -85,7 +85,7 @@ model = AutoModelForCausalLM.from_pretrained(
85
  quantization_config=bnb_config,
86
  device_map="auto",
87
  trust_remote_code=True,
88
- attn_implementation="flash_attention_2",
89
  torch_dtype="auto",
90
  )
91
  model = prepare_model_for_kbit_training(model)
 
1
  #!/usr/bin/env python3
2
  # /// script
3
  # dependencies = [
4
+ # "torch>=2.2.0",
5
  # "trl>=0.17.0",
6
  # "peft>=0.14.0",
7
  # "transformers>=4.48.0",
 
9
  # "bitsandbytes>=0.45.0",
10
  # "trackio",
11
  # "datasets>=3.0.0",
 
12
  # ]
13
  # ///
14
 
 
85
  quantization_config=bnb_config,
86
  device_map="auto",
87
  trust_remote_code=True,
88
+ attn_implementation="sdpa", # Use PyTorch's built-in SDPA
89
  torch_dtype="auto",
90
  )
91
  model = prepare_model_for_kbit_training(model)