100XZX001 committed on
Commit
6d77d18
·
verified ·
1 Parent(s): 659a9e2

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +1 -8
training.py CHANGED
@@ -81,11 +81,6 @@ def map_to_env(action: AgentAction):
81
  def load_model():
82
  model_name = "microsoft/Phi-3-mini-4k-instruct"
83
 
84
- # Manually fix the config to avoid rope_scaling KeyError
85
- config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
86
- if not hasattr(config, 'rope_scaling') or config.rope_scaling is None:
87
- config.rope_scaling = {"type": "linear", "factor": 1.0}
88
-
89
  bnb_config = BitsAndBytesConfig(
90
  load_in_4bit=True,
91
  bnb_4bit_compute_dtype=torch.bfloat16,
@@ -95,11 +90,9 @@ def load_model():
95
 
96
  model = AutoModelForCausalLM.from_pretrained(
97
  model_name,
98
- config=config, # use the patched config
99
  quantization_config=bnb_config,
100
  device_map="auto",
101
- trust_remote_code=True,
102
- attn_implementation="eager", # force eager, avoid flash-attn
103
  torch_dtype=torch.bfloat16,
104
  )
105
 
 
81
  def load_model():
82
  model_name = "microsoft/Phi-3-mini-4k-instruct"
83
 
 
 
 
 
 
84
  bnb_config = BitsAndBytesConfig(
85
  load_in_4bit=True,
86
  bnb_4bit_compute_dtype=torch.bfloat16,
 
90
 
91
  model = AutoModelForCausalLM.from_pretrained(
92
  model_name,
 
93
  quantization_config=bnb_config,
94
  device_map="auto",
95
+ attn_implementation="eager", # avoid flash‑attn
 
96
  torch_dtype=torch.bfloat16,
97
  )
98