kavin57447 committed on
Commit
e9dea07
·
1 Parent(s): 332efeb

Replace flash-attn with PyTorch built-in SDPA (no CUDA compile needed)

Browse files
Files changed (2) hide show
  1. cloud_arena/llm_training.py +1 -1
  2. requirements.txt +0 -1
cloud_arena/llm_training.py CHANGED
@@ -183,7 +183,7 @@ def train_llm(model_name="meta-llama/Llama-3.1-8B",
183
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
184
  model = AutoModelForCausalLM.from_pretrained(
185
  model_name, torch_dtype=torch.bfloat16, token=hf_token,
186
- attn_implementation="flash_attention_2",
187
  ).to(DEVICE)
188
 
189
  lora_config = LoraConfig(
 
183
  tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
184
  model = AutoModelForCausalLM.from_pretrained(
185
  model_name, torch_dtype=torch.bfloat16, token=hf_token,
186
+ attn_implementation="sdpa", # PyTorch built-in, no flash-attn package needed
187
  ).to(DEVICE)
188
 
189
  lora_config = LoraConfig(
requirements.txt CHANGED
@@ -14,4 +14,3 @@ peft==0.12.0
14
  accelerate==0.33.0
15
  bitsandbytes>=0.43.0
16
  sentencepiece
17
- flash-attn>=2.5.0
 
14
  accelerate==0.33.0
15
  bitsandbytes>=0.43.0
16
  sentencepiece