stmasson commited on
Commit
517aef6
·
verified ·
1 Parent(s): ddf67ce

Upload scripts/train_n8n_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_n8n_sft.py +1 -2
scripts/train_n8n_sft.py CHANGED
@@ -11,7 +11,6 @@
11
  # "torch>=2.4.0",
12
  # "einops>=0.8.0",
13
  # "sentencepiece>=0.2.0",
14
- # "flash-attn>=2.5.0",
15
  # ]
16
  # [tool.uv]
17
  # index-strategy = "unsafe-best-match"
@@ -118,7 +117,7 @@ else:
118
  model = AutoModelForCausalLM.from_pretrained(
119
  MODEL_NAME,
120
  torch_dtype=torch.bfloat16,
121
- attn_implementation="flash_attention_2",
122
  device_map="auto",
123
  trust_remote_code=True,
124
  )
 
11
  # "torch>=2.4.0",
12
  # "einops>=0.8.0",
13
  # "sentencepiece>=0.2.0",
 
14
  # ]
15
  # [tool.uv]
16
  # index-strategy = "unsafe-best-match"
 
117
  model = AutoModelForCausalLM.from_pretrained(
118
  MODEL_NAME,
119
  torch_dtype=torch.bfloat16,
120
+ attn_implementation="sdpa", # SDPA is built into PyTorch, no extra install needed
121
  device_map="auto",
122
  trust_remote_code=True,
123
  )