Spaces:
Paused
Paused
Commit ·
e9dea07
1
Parent(s): 332efeb
Replace flash-attn with PyTorch built-in SDPA (no CUDA compile needed)
Browse files- cloud_arena/llm_training.py +1 -1
- requirements.txt +0 -1
cloud_arena/llm_training.py
CHANGED
|
@@ -183,7 +183,7 @@ def train_llm(model_name="meta-llama/Llama-3.1-8B",
|
|
| 183 |
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
|
| 184 |
model = AutoModelForCausalLM.from_pretrained(
|
| 185 |
model_name, torch_dtype=torch.bfloat16, token=hf_token,
|
| 186 |
-
attn_implementation="
|
| 187 |
).to(DEVICE)
|
| 188 |
|
| 189 |
lora_config = LoraConfig(
|
|
|
|
| 183 |
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
|
| 184 |
model = AutoModelForCausalLM.from_pretrained(
|
| 185 |
model_name, torch_dtype=torch.bfloat16, token=hf_token,
|
| 186 |
+
attn_implementation="sdpa", # PyTorch built-in, no flash-attn package needed
|
| 187 |
).to(DEVICE)
|
| 188 |
|
| 189 |
lora_config = LoraConfig(
|
requirements.txt
CHANGED
|
@@ -14,4 +14,3 @@ peft==0.12.0
|
|
| 14 |
accelerate==0.33.0
|
| 15 |
bitsandbytes>=0.43.0
|
| 16 |
sentencepiece
|
| 17 |
-
flash-attn>=2.5.0
|
|
|
|
| 14 |
accelerate==0.33.0
|
| 15 |
bitsandbytes>=0.43.0
|
| 16 |
sentencepiece
|
|
|